supervisely 6.73.410__py3-none-any.whl → 6.73.470__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of supervisely might be problematic. (The original page links to further details.)

Files changed (190):
  1. supervisely/__init__.py +136 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/json_geometries_map.py +2 -0
  4. supervisely/annotation/label.py +80 -3
  5. supervisely/api/annotation_api.py +9 -9
  6. supervisely/api/api.py +67 -43
  7. supervisely/api/app_api.py +72 -5
  8. supervisely/api/dataset_api.py +108 -33
  9. supervisely/api/entity_annotation/figure_api.py +113 -49
  10. supervisely/api/image_api.py +82 -0
  11. supervisely/api/module_api.py +10 -0
  12. supervisely/api/nn/deploy_api.py +15 -9
  13. supervisely/api/nn/ecosystem_models_api.py +201 -0
  14. supervisely/api/nn/neural_network_api.py +12 -3
  15. supervisely/api/pointcloud/pointcloud_api.py +38 -0
  16. supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
  17. supervisely/api/project_api.py +213 -6
  18. supervisely/api/task_api.py +11 -1
  19. supervisely/api/video/video_annotation_api.py +4 -2
  20. supervisely/api/video/video_api.py +79 -1
  21. supervisely/api/video/video_figure_api.py +24 -11
  22. supervisely/api/volume/volume_api.py +38 -0
  23. supervisely/app/__init__.py +1 -1
  24. supervisely/app/content.py +14 -6
  25. supervisely/app/fastapi/__init__.py +1 -0
  26. supervisely/app/fastapi/custom_static_files.py +1 -1
  27. supervisely/app/fastapi/multi_user.py +88 -0
  28. supervisely/app/fastapi/subapp.py +175 -42
  29. supervisely/app/fastapi/templating.py +1 -1
  30. supervisely/app/fastapi/websocket.py +77 -9
  31. supervisely/app/singleton.py +21 -0
  32. supervisely/app/v1/app_service.py +18 -2
  33. supervisely/app/v1/constants.py +7 -1
  34. supervisely/app/widgets/__init__.py +11 -1
  35. supervisely/app/widgets/agent_selector/template.html +1 -0
  36. supervisely/app/widgets/card/card.py +20 -0
  37. supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
  38. supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
  39. supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
  40. supervisely/app/widgets/dialog/dialog.py +12 -0
  41. supervisely/app/widgets/dialog/template.html +2 -1
  42. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  43. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  44. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  45. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  46. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
  47. supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
  48. supervisely/app/widgets/fast_table/fast_table.py +713 -126
  49. supervisely/app/widgets/fast_table/script.js +492 -95
  50. supervisely/app/widgets/fast_table/style.css +54 -0
  51. supervisely/app/widgets/fast_table/template.html +45 -5
  52. supervisely/app/widgets/heatmap/__init__.py +0 -0
  53. supervisely/app/widgets/heatmap/heatmap.py +523 -0
  54. supervisely/app/widgets/heatmap/script.js +378 -0
  55. supervisely/app/widgets/heatmap/style.css +227 -0
  56. supervisely/app/widgets/heatmap/template.html +21 -0
  57. supervisely/app/widgets/input_tag/input_tag.py +102 -15
  58. supervisely/app/widgets/input_tag_list/__init__.py +0 -0
  59. supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
  60. supervisely/app/widgets/input_tag_list/template.html +70 -0
  61. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  62. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  63. supervisely/app/widgets/radio_tabs/template.html +1 -0
  64. supervisely/app/widgets/select/select.py +6 -4
  65. supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
  66. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
  67. supervisely/app/widgets/table/table.py +68 -13
  68. supervisely/app/widgets/tabs/tabs.py +22 -6
  69. supervisely/app/widgets/tabs/template.html +5 -1
  70. supervisely/app/widgets/transfer/style.css +3 -0
  71. supervisely/app/widgets/transfer/template.html +3 -1
  72. supervisely/app/widgets/transfer/transfer.py +48 -45
  73. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  74. supervisely/convert/image/csv/csv_converter.py +24 -15
  75. supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
  76. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
  77. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
  78. supervisely/convert/video/video_converter.py +2 -2
  79. supervisely/geometry/polyline_3d.py +110 -0
  80. supervisely/io/env.py +161 -1
  81. supervisely/nn/artifacts/__init__.py +1 -1
  82. supervisely/nn/artifacts/artifacts.py +10 -2
  83. supervisely/nn/artifacts/detectron2.py +1 -0
  84. supervisely/nn/artifacts/hrda.py +1 -0
  85. supervisely/nn/artifacts/mmclassification.py +20 -0
  86. supervisely/nn/artifacts/mmdetection.py +5 -3
  87. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  88. supervisely/nn/artifacts/ritm.py +1 -0
  89. supervisely/nn/artifacts/rtdetr.py +1 -0
  90. supervisely/nn/artifacts/unet.py +1 -0
  91. supervisely/nn/artifacts/utils.py +3 -0
  92. supervisely/nn/artifacts/yolov5.py +2 -0
  93. supervisely/nn/artifacts/yolov8.py +1 -0
  94. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  95. supervisely/nn/experiments.py +9 -0
  96. supervisely/nn/inference/cache.py +37 -17
  97. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  98. supervisely/nn/inference/inference.py +953 -211
  99. supervisely/nn/inference/inference_request.py +15 -8
  100. supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
  101. supervisely/nn/inference/object_detection/object_detection.py +1 -0
  102. supervisely/nn/inference/predict_app/__init__.py +0 -0
  103. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  104. supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
  105. supervisely/nn/inference/predict_app/gui/gui.py +915 -0
  106. supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
  107. supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
  108. supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
  109. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  110. supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
  111. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  112. supervisely/nn/inference/predict_app/gui/utils.py +399 -0
  113. supervisely/nn/inference/predict_app/predict_app.py +176 -0
  114. supervisely/nn/inference/session.py +47 -39
  115. supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
  116. supervisely/nn/inference/tracking/point_tracking.py +5 -1
  117. supervisely/nn/inference/tracking/tracker_interface.py +4 -0
  118. supervisely/nn/inference/uploader.py +9 -5
  119. supervisely/nn/model/model_api.py +44 -22
  120. supervisely/nn/model/prediction.py +15 -1
  121. supervisely/nn/model/prediction_session.py +70 -14
  122. supervisely/nn/prediction_dto.py +7 -0
  123. supervisely/nn/tracker/__init__.py +6 -8
  124. supervisely/nn/tracker/base_tracker.py +54 -0
  125. supervisely/nn/tracker/botsort/__init__.py +1 -0
  126. supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
  127. supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
  128. supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
  129. supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
  130. supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
  131. supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
  132. supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
  133. supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
  134. supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
  135. supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
  136. supervisely/nn/tracker/botsort_tracker.py +273 -0
  137. supervisely/nn/tracker/calculate_metrics.py +264 -0
  138. supervisely/nn/tracker/utils.py +273 -0
  139. supervisely/nn/tracker/visualize.py +520 -0
  140. supervisely/nn/training/gui/gui.py +152 -49
  141. supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
  142. supervisely/nn/training/gui/model_selector.py +8 -6
  143. supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
  144. supervisely/nn/training/gui/training_artifacts.py +3 -1
  145. supervisely/nn/training/train_app.py +225 -46
  146. supervisely/project/pointcloud_episode_project.py +12 -8
  147. supervisely/project/pointcloud_project.py +12 -8
  148. supervisely/project/project.py +221 -75
  149. supervisely/template/experiment/experiment.html.jinja +105 -55
  150. supervisely/template/experiment/experiment_generator.py +258 -112
  151. supervisely/template/experiment/header.html.jinja +31 -13
  152. supervisely/template/experiment/sly-style.css +7 -2
  153. supervisely/versions.json +3 -1
  154. supervisely/video/sampling.py +42 -20
  155. supervisely/video/video.py +41 -12
  156. supervisely/video_annotation/video_figure.py +38 -4
  157. supervisely/volume/stl_converter.py +2 -0
  158. supervisely/worker_api/agent_rpc.py +24 -1
  159. supervisely/worker_api/rpc_servicer.py +31 -7
  160. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
  161. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
  162. supervisely_lib/__init__.py +6 -1
  163. supervisely/app/widgets/experiment_selector/style.css +0 -27
  164. supervisely/app/widgets/experiment_selector/template.html +0 -61
  165. supervisely/nn/tracker/bot_sort/__init__.py +0 -21
  166. supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
  167. supervisely/nn/tracker/bot_sort/matching.py +0 -127
  168. supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
  169. supervisely/nn/tracker/deep_sort/__init__.py +0 -6
  170. supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
  171. supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
  172. supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
  173. supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
  174. supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
  175. supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
  176. supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
  177. supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
  178. supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
  179. supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
  180. supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
  181. supervisely/nn/tracker/tracker.py +0 -285
  182. supervisely/nn/tracker/utils/kalman_filter.py +0 -492
  183. supervisely/nn/tracking/__init__.py +0 -1
  184. supervisely/nn/tracking/boxmot.py +0 -114
  185. supervisely/nn/tracking/tracking.py +0 -24
  186. /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
  187. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
  188. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
  189. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
  190. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,11 @@
1
1
  import sys
2
+
2
3
  import supervisely
3
4
  from supervisely import *
4
5
 
5
6
  sys.modules['supervisely_lib'] = supervisely
6
- sys.modules['supervisely_lib.api.api'] = supervisely.api
7
+
8
+ for module_name in list(sys.modules.keys()):
9
+ if module_name.startswith("supervisely."):
10
+ new_name = module_name.replace("supervisely.", "supervisely_lib.", 1)
11
+ sys.modules[new_name] = sys.modules[module_name]
@@ -1,27 +0,0 @@
1
- .custom-models-selector-table {
2
- border-collapse: collapse;
3
- }
4
- .custom-models-selector-table tr:nth-child(2n) {
5
- background-color: #f6f8fa;
6
- }
7
- .custom-models-selector-table td,
8
- .custom-models-selector-table th {
9
- border: 1px solid #dfe2e5;
10
- padding: 6px 13px;
11
- text-align: center;
12
- line-height: 20px;
13
- }
14
-
15
- .custom-models-selector-table td {
16
- text-align: left;
17
- }
18
-
19
- .custom-models-selector-table tr td:nth-child(4) {
20
- text-align: center;
21
- }
22
-
23
- .el-radio-group.multi-line label.el-radio {
24
- display: block;
25
- margin-left: 0px;
26
- margin-bottom: 5px;
27
- }
@@ -1,61 +0,0 @@
1
- <link rel="stylesheet" href="./sly/css/app/widgets/custom_models_selector/style.css" />
2
-
3
- <div {% if widget._changes_handled==true %} @change="post('/{{{widget.widget_id}}}/value_changed')" {% endif %}>
4
-
5
- <div v-if="Object.keys(data.{{{widget.widget_id}}}.rowsHtml).length === 0"> You don't have any custom models</div>
6
- <div v-else>
7
-
8
- <div v-if="data.{{{widget.widget_id}}}.taskTypes.length > 1">
9
- <sly-field title="Task Type">
10
- <el-radio-group class="multi-line mt10" :value="state.{{{widget.widget_id}}}.selectedTaskType" {% if
11
- widget._task_type_changes_handled==true %}
12
- @input="(evt) => {state.{{{widget.widget_id}}}.selectedTaskType = evt; state.{{{widget.widget_id}}}.selectedRow = 0; post('/{{{widget.widget_id}}}/task_type_changed')}"
13
- {% else %}
14
- @input="(evt) => {state.{{{widget.widget_id}}}.selectedTaskType = evt; state.{{{widget.widget_id}}}.selectedRow = 0;}"
15
- {% endif %}>
16
-
17
- <el-radio v-for="(item, idx) in {{{widget._task_types}}}" :key="item" :label="item">
18
- {{ item }}
19
- </el-radio>
20
- </el-radio-group>
21
- </sly-field>
22
- </div>
23
-
24
- <div>
25
-
26
- <table class="custom-models-selector-table">
27
- <thead>
28
- <tr>
29
- <th v-for="col in data.{{{widget.widget_id}}}.columns">
30
- <div> {{col}} </div>
31
- </th>
32
- </tr>
33
- </thead>
34
- <tbody>
35
- <tr
36
- v-for="row, ridx in data.{{{widget.widget_id}}}.rowsHtml[state.{{{widget.widget_id}}}.selectedTaskType]">
37
- <td v-for="col, vidx in row">
38
- <div v-if="vidx === 0" style="display: flex;">
39
- <el-radio style="display: flex;" v-model="state.{{{widget.widget_id}}}.selectedRow"
40
- :label="ridx">&#8205;</el-radio>
41
-
42
- <sly-html-compiler :params="{ridx: ridx, vidx: vidx}" :template="col" :data="data"
43
- :state="state"></sly-html-compiler>
44
-
45
- </div>
46
-
47
- <div v-else>
48
-
49
- <sly-html-compiler :params="{ridx: ridx, vidx: vidx}" :template="col" :data="data"
50
- :state="state">
51
- </sly-html-compiler>
52
-
53
- </div>
54
-
55
- </td>
56
- </tr>
57
- </tbody>
58
- </table>
59
- </div>
60
- </div>
61
- </div>
@@ -1,21 +0,0 @@
1
- import torch # pylint: disable=import-error
2
-
3
- try:
4
- import fastreid
5
- except ImportError:
6
- import sys
7
- from pathlib import Path
8
-
9
- fast_reid_repo_url = "https://github.com/supervisely-ecosystem/fast-reid.git"
10
- fast_reid_parent_path = Path(__file__).parent
11
- fast_reid_path = fast_reid_parent_path.joinpath("fast_reid")
12
- if not fast_reid_path.exists():
13
- import subprocess
14
-
15
- subprocess.run(["git", "clone", fast_reid_repo_url, str(fast_reid_path.resolve())])
16
-
17
- sys.path.insert(0, str(fast_reid_path.resolve()))
18
-
19
- import fastreid
20
-
21
- from supervisely.nn.tracker.bot_sort.sly_tracker import BoTTracker
@@ -1,152 +0,0 @@
1
- import cv2
2
- import matplotlib.pyplot as plt
3
- import numpy as np
4
- import torch # pylint: disable=import-error
5
- import torch.nn.functional as F # pylint: disable=import-error
6
- from fastreid.config import get_cfg # pylint: disable=import-error
7
- from fastreid.modeling.meta_arch import build_model # pylint: disable=import-error
8
- from fastreid.utils.checkpoint import Checkpointer # pylint: disable=import-error
9
-
10
- # from torch.backends import cudnn
11
-
12
-
13
- # cudnn.benchmark = True
14
-
15
-
16
- def setup_cfg(config_file, opts):
17
- # load config from file and command-line arguments
18
- cfg = get_cfg()
19
- cfg.merge_from_file(config_file)
20
- cfg.merge_from_list(opts)
21
- cfg.MODEL.BACKBONE.PRETRAIN = False
22
-
23
- cfg.freeze()
24
-
25
- return cfg
26
-
27
-
28
- def postprocess(features):
29
- # Normalize feature to compute cosine distance
30
- features = F.normalize(features)
31
- features = features.cpu().data.numpy()
32
- return features
33
-
34
-
35
- def preprocess(image, input_size):
36
- if len(image.shape) == 3:
37
- padded_img = np.ones((input_size[1], input_size[0], 3), dtype=np.uint8) * 114
38
- else:
39
- padded_img = np.ones(input_size) * 114
40
- img = np.array(image)
41
- r = min(input_size[1] / img.shape[0], input_size[0] / img.shape[1])
42
- resized_img = cv2.resize(
43
- img,
44
- (int(img.shape[1] * r), int(img.shape[0] * r)),
45
- interpolation=cv2.INTER_LINEAR,
46
- )
47
- padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
48
-
49
- return padded_img, r
50
-
51
-
52
- class FastReIDInterface:
53
- def __init__(self, config_file, weights_path, device, batch_size=8):
54
- super(FastReIDInterface, self).__init__()
55
- if device != "cpu":
56
- self.device = "cuda"
57
- else:
58
- self.device = "cpu"
59
-
60
- self.batch_size = batch_size
61
-
62
- self.cfg = setup_cfg(config_file, ["MODEL.WEIGHTS", weights_path])
63
-
64
- self.model = build_model(self.cfg)
65
- self.model.eval()
66
-
67
- Checkpointer(self.model).load(weights_path)
68
-
69
- if self.device != "cpu":
70
- self.model = self.model.eval().to(device="cuda").half()
71
- else:
72
- self.model = self.model.eval()
73
-
74
- self.pH, self.pW = self.cfg.INPUT.SIZE_TEST
75
-
76
- def inference(self, image, detections):
77
-
78
- if detections is None or np.size(detections) == 0:
79
- return []
80
-
81
- H, W, _ = np.shape(image)
82
-
83
- batch_patches = []
84
- patches = []
85
- for d in range(np.size(detections, 0)):
86
- tlbr = detections[d, :4].astype(np.int_)
87
- tlbr[0] = max(0, tlbr[0])
88
- tlbr[1] = max(0, tlbr[1])
89
- tlbr[2] = min(W - 1, tlbr[2])
90
- tlbr[3] = min(H - 1, tlbr[3])
91
- patch = image[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2], :]
92
-
93
- # the model expects RGB inputs
94
- patch = patch[:, :, ::-1]
95
-
96
- # Apply pre-processing to image.
97
- patch = cv2.resize(
98
- patch, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_LINEAR
99
- )
100
- # patch, scale = preprocess(patch, self.cfg.INPUT.SIZE_TEST[::-1])
101
-
102
- # plt.figure()
103
- # plt.imshow(patch)
104
- # plt.show()
105
-
106
- # Make shape with a new batch dimension which is adapted for network input
107
- patch = torch.as_tensor(patch.astype("float32").transpose(2, 0, 1))
108
- patch = patch.to(device=self.device).half()
109
-
110
- patches.append(patch)
111
-
112
- if (d + 1) % self.batch_size == 0:
113
- patches = torch.stack(patches, dim=0)
114
- batch_patches.append(patches)
115
- patches = []
116
-
117
- if len(patches):
118
- patches = torch.stack(patches, dim=0)
119
- batch_patches.append(patches)
120
-
121
- features = np.zeros((0, 2048))
122
- # features = np.zeros((0, 768))
123
-
124
- for patches in batch_patches:
125
-
126
- # Run model
127
- patches_ = torch.clone(patches)
128
- pred = self.model(patches)
129
- pred[torch.isinf(pred)] = 1.0
130
-
131
- feat = postprocess(pred)
132
-
133
- nans = np.isnan(np.sum(feat, axis=1))
134
- if np.isnan(feat).any():
135
- for n in range(np.size(nans)):
136
- if nans[n]:
137
- # patch_np = patches[n, ...].squeeze().transpose(1, 2, 0).cpu().numpy()
138
- patch_np = patches_[n, ...]
139
- patch_np_ = torch.unsqueeze(patch_np, 0)
140
- pred_ = self.model(patch_np_)
141
-
142
- patch_np = torch.squeeze(patch_np).cpu()
143
- patch_np = torch.permute(patch_np, (1, 2, 0)).int()
144
- patch_np = patch_np.numpy()
145
-
146
- plt.figure()
147
- plt.imshow(patch_np)
148
- plt.show()
149
-
150
- features = np.vstack((features, feat))
151
-
152
- return features
@@ -1,127 +0,0 @@
1
- import lap # pylint: disable=import-error
2
- import numpy as np
3
- from cython_bbox import bbox_overlaps as bbox_ious # pylint: disable=import-error
4
- from scipy.spatial.distance import cdist
5
-
6
- from supervisely.nn.tracker.utils import kalman_filter
7
-
8
-
9
- def linear_assignment(cost_matrix, thresh):
10
- if cost_matrix.size == 0:
11
- return (
12
- np.empty((0, 2), dtype=int),
13
- tuple(range(cost_matrix.shape[0])),
14
- tuple(range(cost_matrix.shape[1])),
15
- )
16
- matches, unmatched_a, unmatched_b = [], [], []
17
- cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
18
- for ix, mx in enumerate(x):
19
- if mx >= 0:
20
- matches.append([ix, mx])
21
- unmatched_a = np.where(x < 0)[0]
22
- unmatched_b = np.where(y < 0)[0]
23
- matches = np.asarray(matches)
24
- return matches, unmatched_a, unmatched_b
25
-
26
-
27
- def ious(atlbrs, btlbrs):
28
- """
29
- Compute cost based on IoU
30
- :type atlbrs: list[tlbr] | np.ndarray
31
- :type atlbrs: list[tlbr] | np.ndarray
32
-
33
- :rtype ious np.ndarray
34
- """
35
- ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=float)
36
- if ious.size == 0:
37
- return ious
38
-
39
- ious = bbox_ious(
40
- np.ascontiguousarray(atlbrs, dtype=float), np.ascontiguousarray(btlbrs, dtype=float)
41
- )
42
-
43
- return ious
44
-
45
-
46
- def iou_distance(atracks, btracks):
47
- """
48
- Compute cost based on IoU
49
- :type atracks: list[STrack]
50
- :type btracks: list[STrack]
51
-
52
- :rtype cost_matrix np.ndarray
53
- """
54
-
55
- if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (
56
- len(btracks) > 0 and isinstance(btracks[0], np.ndarray)
57
- ):
58
- atlbrs = atracks
59
- btlbrs = btracks
60
- else:
61
- atlbrs = [track.tlbr for track in atracks]
62
- btlbrs = [track.tlbr for track in btracks]
63
- _ious = ious(atlbrs, btlbrs)
64
- cost_matrix = 1 - _ious
65
-
66
- return cost_matrix
67
-
68
-
69
- def embedding_distance(tracks, detections, metric="cosine"):
70
- """
71
- :param tracks: list[STrack]
72
- :param detections: list[BaseTrack]
73
- :param metric:
74
- :return: cost_matrix np.ndarray
75
- """
76
-
77
- cost_matrix = np.zeros((len(tracks), len(detections)), dtype=float)
78
- if cost_matrix.size == 0:
79
- return cost_matrix
80
- det_features = np.asarray([track.curr_feat for track in detections], dtype=float)
81
- track_features = np.asarray([track.smooth_feat for track in tracks], dtype=float)
82
-
83
- cost_matrix = np.maximum(
84
- 0.0, cdist(track_features, det_features, metric)
85
- ) # / 2.0 # Nomalized features
86
- return cost_matrix
87
-
88
-
89
- def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
90
- if cost_matrix.size == 0:
91
- return cost_matrix
92
- gating_dim = 2 if only_position else 4
93
- gating_threshold = kalman_filter.chi2inv95[gating_dim]
94
- # measurements = np.asarray([det.to_xyah() for det in detections])
95
- measurements = np.asarray([det.to_xywh() for det in detections])
96
- for row, track in enumerate(tracks):
97
- gating_distance = kf.gating_distance(
98
- track.mean, track.covariance, measurements, only_position, metric="maha"
99
- )
100
- cost_matrix[row, gating_distance > gating_threshold] = np.inf
101
- cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
102
- return cost_matrix
103
-
104
-
105
- def fuse_iou(cost_matrix, tracks, detections):
106
- if cost_matrix.size == 0:
107
- return cost_matrix
108
- reid_sim = 1 - cost_matrix
109
- iou_dist = iou_distance(tracks, detections)
110
- iou_sim = 1 - iou_dist
111
- fuse_sim = reid_sim * (1 + iou_sim) / 2
112
- det_scores = np.array([det.score for det in detections])
113
- det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
114
- # fuse_sim = fuse_sim * (1 + det_scores) / 2
115
- fuse_cost = 1 - fuse_sim
116
- return fuse_cost
117
-
118
-
119
- def fuse_score(cost_matrix, detections):
120
- if cost_matrix.size == 0:
121
- return cost_matrix
122
- iou_sim = 1 - cost_matrix
123
- det_scores = np.array([det.score for det in detections])
124
- det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
125
- fuse_sim = iou_sim * det_scores
126
- fuse_cost = 1 - fuse_sim
127
- return fuse_cost