supervisely 6.73.417__py3-none-any.whl → 6.73.419__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/api/entity_annotation/figure_api.py +89 -45
- supervisely/nn/inference/inference.py +61 -45
- supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
- supervisely/nn/inference/object_detection/object_detection.py +1 -0
- supervisely/nn/inference/session.py +4 -4
- supervisely/nn/model/model_api.py +31 -20
- supervisely/nn/model/prediction.py +11 -0
- supervisely/nn/model/prediction_session.py +33 -6
- supervisely/nn/tracker/__init__.py +1 -2
- supervisely/nn/tracker/base_tracker.py +44 -0
- supervisely/nn/tracker/botsort/__init__.py +1 -0
- supervisely/nn/tracker/botsort/botsort_config.yaml +31 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
- supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
- supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
- supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
- supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
- supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
- supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
- supervisely/nn/tracker/botsort_tracker.py +259 -0
- supervisely/project/project.py +1 -1
- {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/METADATA +5 -3
- {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/RECORD +29 -42
- supervisely/nn/tracker/bot_sort/__init__.py +0 -21
- supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
- supervisely/nn/tracker/bot_sort/matching.py +0 -127
- supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
- supervisely/nn/tracker/deep_sort/__init__.py +0 -6
- supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
- supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
- supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
- supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
- supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
- supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
- supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
- supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
- supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
- supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
- supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
- supervisely/nn/tracker/tracker.py +0 -285
- supervisely/nn/tracker/utils/kalman_filter.py +0 -492
- supervisely/nn/tracking/__init__.py +0 -1
- supervisely/nn/tracking/boxmot.py +0 -114
- supervisely/nn/tracking/tracking.py +0 -24
- /supervisely/nn/tracker/{utils → botsort/osnet_reid}/__init__.py +0 -0
- {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/LICENSE +0 -0
- {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/WHEEL +0 -0
- {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
import cv2
|
|
3
|
+
import numpy as np
|
|
4
|
+
from .osnet import osnet_x1_0
|
|
5
|
+
from collections import OrderedDict
|
|
6
|
+
from supervisely import logger
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
# pylint: disable=import-error
|
|
10
|
+
import torch
|
|
11
|
+
from torch.nn import functional as F
|
|
12
|
+
except ImportError:
|
|
13
|
+
logger.warning("torch is not installed, OSNet re-ID cannot be used.")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class OsnetReIDModel:
    """OSNet-x1.0 appearance-feature extractor used for re-identification.

    Each detection box is cut out of a BGR image, resized to ``input_shape``
    (height, width), converted to RGB, normalized with ImageNet mean/std and
    passed through ``osnet_x1_0``. The resulting embeddings are L2-normalized.
    """

    # ImageNet statistics applied to RGB crops scaled to [0, 1].
    _MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    _STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    # Width of the embedding returned for an empty input.
    _FEATURE_DIM = 512

    def __init__(self, weights_path: Path = None, device: "torch.device" = None, half: bool = False):
        """
        :param weights_path: checkpoint to load. If None, ``osnet_x1_0`` is
            built with ``pretrained=True`` instead.
        :param device: device to run on (torch.device or a string such as
            "cuda:0"). Defaults to CPU. The default is resolved here rather
            than in the signature so that merely importing this module does
            not fail when torch is unavailable.
        :param half: run the model and its inputs in float16.
        """
        if device is None:
            device = torch.device("cpu")
        elif not isinstance(device, torch.device):
            device = torch.device(device)
        self.device = device
        self.half = half
        self.input_shape = (256, 128)  # (height, width) of every crop fed to the model

        # ``use_gpu`` is a flag: pass a real bool, a torch.device object is always truthy.
        use_gpu = self.device.type == "cuda"
        use_builtin_weights = weights_path is None
        self.model = osnet_x1_0(
            num_classes=1000, loss="softmax", pretrained=use_builtin_weights, use_gpu=use_gpu
        )
        if not use_builtin_weights:
            self.load_pretrained_weights(weights_path)
        self.model.to(self.device).eval()
        if self.half:
            self.model.half()

    def load_pretrained_weights(self, weight_path: Path):
        """Load a checkpoint into ``self.model``, keeping only compatible tensors.

        The checkpoint may be a raw state dict or a dict holding one under the
        ``"state_dict"`` key. A leading ``"module."`` prefix (commonly left by
        data-parallel wrappers) is stripped. Entries whose key is unknown to
        the model, or whose shape differs, are skipped.

        :raises ValueError: if not a single tensor in the checkpoint matches
            the model, since the model would otherwise run with random weights.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(weight_path, map_location=self.device)
        state_dict = checkpoint.get("state_dict", checkpoint)
        model_dict = self.model.state_dict()

        prefix = "module."
        compatible = OrderedDict()
        for key, tensor in state_dict.items():
            if key.startswith(prefix):
                key = key[len(prefix):]
            target = model_dict.get(key)
            if target is not None and target.size() == tensor.size():
                compatible[key] = tensor

        if not compatible:
            raise ValueError(f"No weights in '{weight_path}' match the OSNet model")

        model_dict.update(compatible)
        self.model.load_state_dict(model_dict)

    def get_features(self, xyxys, img: np.ndarray):
        """Return one L2-normalized embedding per box.

        :param xyxys: boxes as rows of (x1, y1, x2, y2) in pixel coordinates.
        :param img: BGR image the boxes refer to.
        :return: float32 array of shape (N, D); (0, 512) when there are no boxes.
        """
        xyxys = np.asarray(xyxys)
        if xyxys.size == 0:
            return np.empty((0, self._FEATURE_DIM), dtype=np.float32)

        crops = self._get_crops(xyxys, img)
        with torch.no_grad():
            features = self.model(crops)
        # Normalize in float32 so half-precision runs return the same dtype
        # (and avoid float16 precision loss in the norm).
        return F.normalize(features.float(), dim=1).cpu().numpy()

    @staticmethod
    def _clip_box(box, width, height):
        """Round a (x1, y1, x2, y2) box and clip it to a non-empty region of a width x height image."""
        x1, y1, x2, y2 = (int(v) for v in np.asarray(box).round())
        x1 = min(max(0, x1), width - 1)
        y1 = min(max(0, y1), height - 1)
        # Guarantee at least one pixel so cv2.resize never receives an empty slice.
        x2 = min(width, max(x2, x1 + 1))
        y2 = min(height, max(y2, y1 + 1))
        return x1, y1, x2, y2

    def _get_crops(self, xyxys, img):
        """Cut, resize and normalize every box into an (N, 3, H, W) batch on ``self.device``."""
        height, width = img.shape[:2]
        out_h, out_w = self.input_shape
        crops = []
        for box in xyxys:
            x1, y1, x2, y2 = self._clip_box(box, width, height)
            crop = cv2.resize(img[y1:y2, x1:x2], (out_w, out_h))  # cv2 takes (width, height)
            crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
            crop = (crop.astype(np.float32) / 255.0 - self._MEAN) / self._STD
            crops.append(torch.from_numpy(crop).permute(2, 0, 1))  # HWC -> CHW

        dtype = torch.float16 if self.half else torch.float32
        return torch.stack(crops).to(self.device, dtype=dtype)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class OsnetReIDInterface:
    """Tracker-facing wrapper that turns detections into OSNet re-ID embeddings."""

    def __init__(self, weights: Path, device: str = "cpu", fp16: bool = False):
        """
        :param weights: checkpoint path forwarded to ``OsnetReIDModel``
            (None selects the model's built-in pretrained weights).
        :param device: torch device string, e.g. "cpu" or "cuda:0".
        :param fp16: run inference in half precision.
        """
        self.device = torch.device(device)
        self.fp16 = fp16
        self.model = OsnetReIDModel(weights, self.device, half=fp16)

    def inference(self, image: np.ndarray, detections: np.ndarray) -> np.ndarray:
        """Compute one embedding per detection.

        :param image: image the detections were produced on (passed unchanged
            to ``OsnetReIDModel.get_features``).
        :param detections: array whose first four columns are
            (left, top, right, bottom). Extra columns are ignored. A single
            detection may be given as a 1-D row.
        :return: float32 array of embeddings, one row per detection;
            shape (0, 512) when there are no detections.
        """
        if detections is None or np.size(detections) == 0:
            return np.zeros((0, 512), dtype=np.float32)  # empty feature set

        # Promote a single 1-D detection to a one-row 2-D array.
        detections = np.atleast_2d(np.asarray(detections))
        xyxys = detections[:, :4]  # left, top, right, bottom
        return self.model.get_features(xyxys, image)
|
|
87
|
+
|
|
88
|
+
|
|
File without changes
|
|
@@ -1,60 +1,59 @@
|
|
|
1
|
-
import copy
|
|
2
|
-
import time
|
|
3
|
-
|
|
4
1
|
import cv2
|
|
5
2
|
import matplotlib.pyplot as plt
|
|
6
3
|
import numpy as np
|
|
4
|
+
import copy
|
|
5
|
+
import time
|
|
7
6
|
|
|
8
7
|
|
|
9
8
|
class GMC:
|
|
10
|
-
def __init__(self, method=
|
|
9
|
+
def __init__(self, method='sparseOptFlow', downscale=2, verbose=None):
|
|
11
10
|
super(GMC, self).__init__()
|
|
12
11
|
|
|
13
12
|
self.method = method
|
|
14
13
|
self.downscale = max(1, int(downscale))
|
|
15
14
|
|
|
16
|
-
if self.method ==
|
|
15
|
+
if self.method == 'orb':
|
|
17
16
|
self.detector = cv2.FastFeatureDetector_create(20)
|
|
18
17
|
self.extractor = cv2.ORB_create()
|
|
19
18
|
self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
|
|
20
19
|
|
|
21
|
-
elif self.method ==
|
|
22
|
-
self.detector = cv2.SIFT_create(
|
|
23
|
-
|
|
24
|
-
)
|
|
25
|
-
self.extractor = cv2.SIFT_create(
|
|
26
|
-
nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20
|
|
27
|
-
)
|
|
20
|
+
elif self.method == 'sift':
|
|
21
|
+
self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
|
|
22
|
+
self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
|
|
28
23
|
self.matcher = cv2.BFMatcher(cv2.NORM_L2)
|
|
29
24
|
|
|
30
|
-
elif self.method ==
|
|
25
|
+
elif self.method == 'ecc':
|
|
31
26
|
number_of_iterations = 5000
|
|
32
27
|
termination_eps = 1e-6
|
|
33
28
|
self.warp_mode = cv2.MOTION_EUCLIDEAN
|
|
34
|
-
self.criteria = (
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
elif self.method == "sparseOptFlow":
|
|
41
|
-
self.feature_params = dict(
|
|
42
|
-
maxCorners=1000,
|
|
43
|
-
qualityLevel=0.01,
|
|
44
|
-
minDistance=1,
|
|
45
|
-
blockSize=3,
|
|
46
|
-
useHarrisDetector=False,
|
|
47
|
-
k=0.04,
|
|
48
|
-
)
|
|
29
|
+
self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
|
|
30
|
+
|
|
31
|
+
elif self.method == 'sparseOptFlow':
|
|
32
|
+
self.feature_params = dict(maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3,
|
|
33
|
+
useHarrisDetector=False, k=0.04)
|
|
49
34
|
# self.gmc_file = open('GMC_results.txt', 'w')
|
|
50
35
|
|
|
51
|
-
elif self.method ==
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
36
|
+
elif self.method == 'file' or self.method == 'files':
|
|
37
|
+
seqName = verbose[0]
|
|
38
|
+
ablation = verbose[1]
|
|
39
|
+
if ablation:
|
|
40
|
+
filePath = r'tracker/GMC_files/MOT17_ablation'
|
|
41
|
+
else:
|
|
42
|
+
filePath = r'tracker/GMC_files/MOTChallenge'
|
|
43
|
+
|
|
44
|
+
if '-FRCNN' in seqName:
|
|
45
|
+
seqName = seqName[:-6]
|
|
46
|
+
elif '-DPM' in seqName:
|
|
47
|
+
seqName = seqName[:-4]
|
|
48
|
+
elif '-SDP' in seqName:
|
|
49
|
+
seqName = seqName[:-4]
|
|
50
|
+
|
|
51
|
+
self.gmcFile = open(filePath + "/GMC-" + seqName + ".txt", 'r')
|
|
55
52
|
|
|
56
|
-
|
|
57
|
-
|
|
53
|
+
if self.gmcFile is None:
|
|
54
|
+
raise ValueError("Error: Unable to open GMC file in directory:" + filePath)
|
|
55
|
+
elif self.method == 'none' or self.method == 'None':
|
|
56
|
+
self.method = 'none'
|
|
58
57
|
else:
|
|
59
58
|
raise ValueError("Error: Unknown CMC method:" + method)
|
|
60
59
|
|
|
@@ -65,15 +64,15 @@ class GMC:
|
|
|
65
64
|
self.initializedFirstFrame = False
|
|
66
65
|
|
|
67
66
|
def apply(self, raw_frame, detections=None):
|
|
68
|
-
if self.method ==
|
|
67
|
+
if self.method == 'orb' or self.method == 'sift':
|
|
69
68
|
return self.applyFeaures(raw_frame, detections)
|
|
70
|
-
elif self.method ==
|
|
69
|
+
elif self.method == 'ecc':
|
|
71
70
|
return self.applyEcc(raw_frame, detections)
|
|
72
|
-
elif self.method ==
|
|
71
|
+
elif self.method == 'sparseOptFlow':
|
|
73
72
|
return self.applySparseOptFlow(raw_frame, detections)
|
|
74
|
-
elif self.method ==
|
|
73
|
+
elif self.method == 'file':
|
|
75
74
|
return self.applyFile(raw_frame, detections)
|
|
76
|
-
elif self.method ==
|
|
75
|
+
elif self.method == 'none':
|
|
77
76
|
return np.eye(2, 3)
|
|
78
77
|
else:
|
|
79
78
|
return np.eye(2, 3)
|
|
@@ -105,11 +104,9 @@ class GMC:
|
|
|
105
104
|
# Run the ECC algorithm. The results are stored in warp_matrix.
|
|
106
105
|
# (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
|
|
107
106
|
try:
|
|
108
|
-
(cc, H) = cv2.findTransformECC(
|
|
109
|
-
self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1
|
|
110
|
-
)
|
|
107
|
+
(cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
|
|
111
108
|
except:
|
|
112
|
-
print(
|
|
109
|
+
print('Warning: find transform failed. Set warp as identity')
|
|
113
110
|
|
|
114
111
|
return H
|
|
115
112
|
|
|
@@ -130,11 +127,11 @@ class GMC:
|
|
|
130
127
|
# find the keypoints
|
|
131
128
|
mask = np.zeros_like(frame)
|
|
132
129
|
# mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
|
|
133
|
-
mask[int(0.02 * height)
|
|
130
|
+
mask[int(0.02 * height): int(0.98 * height), int(0.02 * width): int(0.98 * width)] = 255
|
|
134
131
|
if detections is not None:
|
|
135
132
|
for det in detections:
|
|
136
133
|
tlbr = (det[:4] / self.downscale).astype(np.int_)
|
|
137
|
-
mask[tlbr[1]
|
|
134
|
+
mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
|
|
138
135
|
|
|
139
136
|
keypoints = self.detector.detect(frame, mask)
|
|
140
137
|
|
|
@@ -176,14 +173,11 @@ class GMC:
|
|
|
176
173
|
prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
|
|
177
174
|
currKeyPointLocation = keypoints[m.trainIdx].pt
|
|
178
175
|
|
|
179
|
-
spatialDistance = (
|
|
180
|
-
|
|
181
|
-
prevKeyPointLocation[1] - currKeyPointLocation[1],
|
|
182
|
-
)
|
|
176
|
+
spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0],
|
|
177
|
+
prevKeyPointLocation[1] - currKeyPointLocation[1])
|
|
183
178
|
|
|
184
|
-
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and
|
|
185
|
-
|
|
186
|
-
):
|
|
179
|
+
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
|
|
180
|
+
(np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
|
|
187
181
|
spatialDistances.append(spatialDistance)
|
|
188
182
|
matches.append(m)
|
|
189
183
|
|
|
@@ -233,7 +227,7 @@ class GMC:
|
|
|
233
227
|
H[0, 2] *= self.downscale
|
|
234
228
|
H[1, 2] *= self.downscale
|
|
235
229
|
else:
|
|
236
|
-
print(
|
|
230
|
+
print('Warning: not enough matching points')
|
|
237
231
|
|
|
238
232
|
# Store to next iteration
|
|
239
233
|
self.prevFrame = frame.copy()
|
|
@@ -271,9 +265,7 @@ class GMC:
|
|
|
271
265
|
return H
|
|
272
266
|
|
|
273
267
|
# find correspondences
|
|
274
|
-
matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(
|
|
275
|
-
self.prevFrame, frame, self.prevKeyPoints, None
|
|
276
|
-
)
|
|
268
|
+
matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
|
|
277
269
|
|
|
278
270
|
# leave good correspondences only
|
|
279
271
|
prevPoints = []
|
|
@@ -296,7 +288,7 @@ class GMC:
|
|
|
296
288
|
H[0, 2] *= self.downscale
|
|
297
289
|
H[1, 2] *= self.downscale
|
|
298
290
|
else:
|
|
299
|
-
print(
|
|
291
|
+
print('Warning: not enough matching points')
|
|
300
292
|
|
|
301
293
|
# Store to next iteration
|
|
302
294
|
self.prevFrame = frame.copy()
|
|
@@ -313,7 +305,7 @@ class GMC:
|
|
|
313
305
|
def applyFile(self, raw_frame, detections=None):
|
|
314
306
|
line = self.gmcFile.readline()
|
|
315
307
|
tokens = line.split("\t")
|
|
316
|
-
H = np.eye(2, 3, dtype=np.
|
|
308
|
+
H = np.eye(2, 3, dtype=np.float32)
|
|
317
309
|
H[0, 0] = float(tokens[1])
|
|
318
310
|
H[0, 1] = float(tokens[2])
|
|
319
311
|
H[0, 2] = float(tokens[3])
|
|
@@ -321,4 +313,4 @@ class GMC:
|
|
|
321
313
|
H[1, 1] = float(tokens[5])
|
|
322
314
|
H[1, 2] = float(tokens[6])
|
|
323
315
|
|
|
324
|
-
return H
|
|
316
|
+
return H
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
# vim: expandtab:ts=4:sw=4
|
|
2
2
|
import numpy as np
|
|
3
|
+
import scipy.linalg
|
|
4
|
+
|
|
3
5
|
|
|
4
6
|
"""
|
|
5
7
|
Table for the 0.95 quantile of the chi-square distribution with N degrees of
|
|
@@ -24,13 +26,13 @@ class KalmanFilter(object):
|
|
|
24
26
|
|
|
25
27
|
The 8-dimensional state space
|
|
26
28
|
|
|
27
|
-
x, y,
|
|
29
|
+
x, y, w, h, vx, vy, vw, vh
|
|
28
30
|
|
|
29
|
-
contains the bounding box center position (x, y),
|
|
31
|
+
contains the bounding box center position (x, y), width w, height h,
|
|
30
32
|
and their respective velocities.
|
|
31
33
|
|
|
32
34
|
Object motion follows a constant velocity model. The bounding box location
|
|
33
|
-
(x, y,
|
|
35
|
+
(x, y, w, h) is taken as direct observation of the state space (linear
|
|
34
36
|
observation model).
|
|
35
37
|
|
|
36
38
|
"""
|
|
@@ -56,8 +58,8 @@ class KalmanFilter(object):
|
|
|
56
58
|
Parameters
|
|
57
59
|
----------
|
|
58
60
|
measurement : ndarray
|
|
59
|
-
Bounding box coordinates (x, y,
|
|
60
|
-
|
|
61
|
+
Bounding box coordinates (x, y, w, h) with center position (x, y),
|
|
62
|
+
width w, and height h.
|
|
61
63
|
|
|
62
64
|
Returns
|
|
63
65
|
-------
|
|
@@ -72,13 +74,13 @@ class KalmanFilter(object):
|
|
|
72
74
|
mean = np.r_[mean_pos, mean_vel]
|
|
73
75
|
|
|
74
76
|
std = [
|
|
77
|
+
2 * self._std_weight_position * measurement[2],
|
|
75
78
|
2 * self._std_weight_position * measurement[3],
|
|
79
|
+
2 * self._std_weight_position * measurement[2],
|
|
76
80
|
2 * self._std_weight_position * measurement[3],
|
|
77
|
-
|
|
78
|
-
2 * self._std_weight_position * measurement[3],
|
|
81
|
+
10 * self._std_weight_velocity * measurement[2],
|
|
79
82
|
10 * self._std_weight_velocity * measurement[3],
|
|
80
|
-
10 * self._std_weight_velocity * measurement[
|
|
81
|
-
1e-5,
|
|
83
|
+
10 * self._std_weight_velocity * measurement[2],
|
|
82
84
|
10 * self._std_weight_velocity * measurement[3]]
|
|
83
85
|
covariance = np.diag(np.square(std))
|
|
84
86
|
return mean, covariance
|
|
@@ -103,18 +105,18 @@ class KalmanFilter(object):
|
|
|
103
105
|
|
|
104
106
|
"""
|
|
105
107
|
std_pos = [
|
|
108
|
+
self._std_weight_position * mean[2],
|
|
106
109
|
self._std_weight_position * mean[3],
|
|
107
|
-
self._std_weight_position * mean[
|
|
108
|
-
1e-2,
|
|
110
|
+
self._std_weight_position * mean[2],
|
|
109
111
|
self._std_weight_position * mean[3]]
|
|
110
112
|
std_vel = [
|
|
113
|
+
self._std_weight_velocity * mean[2],
|
|
111
114
|
self._std_weight_velocity * mean[3],
|
|
112
|
-
self._std_weight_velocity * mean[
|
|
113
|
-
1e-5,
|
|
115
|
+
self._std_weight_velocity * mean[2],
|
|
114
116
|
self._std_weight_velocity * mean[3]]
|
|
115
117
|
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
|
|
116
118
|
|
|
117
|
-
mean = np.dot(self._motion_mat
|
|
119
|
+
mean = np.dot(mean, self._motion_mat.T)
|
|
118
120
|
covariance = np.linalg.multi_dot((
|
|
119
121
|
self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
|
|
120
122
|
|
|
@@ -138,9 +140,9 @@ class KalmanFilter(object):
|
|
|
138
140
|
|
|
139
141
|
"""
|
|
140
142
|
std = [
|
|
143
|
+
self._std_weight_position * mean[2],
|
|
141
144
|
self._std_weight_position * mean[3],
|
|
142
|
-
self._std_weight_position * mean[
|
|
143
|
-
1e-1,
|
|
145
|
+
self._std_weight_position * mean[2],
|
|
144
146
|
self._std_weight_position * mean[3]]
|
|
145
147
|
innovation_cov = np.diag(np.square(std))
|
|
146
148
|
|
|
@@ -149,6 +151,45 @@ class KalmanFilter(object):
|
|
|
149
151
|
self._update_mat, covariance, self._update_mat.T))
|
|
150
152
|
return mean, covariance + innovation_cov
|
|
151
153
|
|
|
154
|
+
def multi_predict(self, mean, covariance):
|
|
155
|
+
"""Run Kalman filter prediction step (Vectorized version).
|
|
156
|
+
Parameters
|
|
157
|
+
----------
|
|
158
|
+
mean : ndarray
|
|
159
|
+
The Nx8 dimensional mean matrix of the object states at the previous
|
|
160
|
+
time step.
|
|
161
|
+
covariance : ndarray
|
|
162
|
+
The Nx8x8 dimensional covariance matrics of the object states at the
|
|
163
|
+
previous time step.
|
|
164
|
+
Returns
|
|
165
|
+
-------
|
|
166
|
+
(ndarray, ndarray)
|
|
167
|
+
Returns the mean vector and covariance matrix of the predicted
|
|
168
|
+
state. Unobserved velocities are initialized to 0 mean.
|
|
169
|
+
"""
|
|
170
|
+
std_pos = [
|
|
171
|
+
self._std_weight_position * mean[:, 2],
|
|
172
|
+
self._std_weight_position * mean[:, 3],
|
|
173
|
+
self._std_weight_position * mean[:, 2],
|
|
174
|
+
self._std_weight_position * mean[:, 3]]
|
|
175
|
+
std_vel = [
|
|
176
|
+
self._std_weight_velocity * mean[:, 2],
|
|
177
|
+
self._std_weight_velocity * mean[:, 3],
|
|
178
|
+
self._std_weight_velocity * mean[:, 2],
|
|
179
|
+
self._std_weight_velocity * mean[:, 3]]
|
|
180
|
+
sqr = np.square(np.r_[std_pos, std_vel]).T
|
|
181
|
+
|
|
182
|
+
motion_cov = []
|
|
183
|
+
for i in range(len(mean)):
|
|
184
|
+
motion_cov.append(np.diag(sqr[i]))
|
|
185
|
+
motion_cov = np.asarray(motion_cov)
|
|
186
|
+
|
|
187
|
+
mean = np.dot(mean, self._motion_mat.T)
|
|
188
|
+
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
|
|
189
|
+
covariance = np.dot(left, self._motion_mat.T) + motion_cov
|
|
190
|
+
|
|
191
|
+
return mean, covariance
|
|
192
|
+
|
|
152
193
|
def update(self, mean, covariance, measurement):
|
|
153
194
|
"""Run Kalman filter correction step.
|
|
154
195
|
|
|
@@ -159,8 +200,8 @@ class KalmanFilter(object):
|
|
|
159
200
|
covariance : ndarray
|
|
160
201
|
The state's covariance matrix (8x8 dimensional).
|
|
161
202
|
measurement : ndarray
|
|
162
|
-
The 4 dimensional measurement vector (x, y,
|
|
163
|
-
is the center position,
|
|
203
|
+
The 4 dimensional measurement vector (x, y, w, h), where (x, y)
|
|
204
|
+
is the center position, w the width, and h the height of the
|
|
164
205
|
bounding box.
|
|
165
206
|
|
|
166
207
|
Returns
|
|
@@ -169,8 +210,6 @@ class KalmanFilter(object):
|
|
|
169
210
|
Returns the measurement-corrected state distribution.
|
|
170
211
|
|
|
171
212
|
"""
|
|
172
|
-
import scipy.linalg # pylint: disable=import-error
|
|
173
|
-
|
|
174
213
|
projected_mean, projected_cov = self.project(mean, covariance)
|
|
175
214
|
|
|
176
215
|
chol_factor, lower = scipy.linalg.cho_factor(
|
|
@@ -186,13 +225,11 @@ class KalmanFilter(object):
|
|
|
186
225
|
return new_mean, new_covariance
|
|
187
226
|
|
|
188
227
|
def gating_distance(self, mean, covariance, measurements,
|
|
189
|
-
only_position=False):
|
|
228
|
+
only_position=False, metric='maha'):
|
|
190
229
|
"""Compute gating distance between state distribution and measurements.
|
|
191
|
-
|
|
192
230
|
A suitable distance threshold can be obtained from `chi2inv95`. If
|
|
193
231
|
`only_position` is False, the chi-square distribution has 4 degrees of
|
|
194
232
|
freedom, otherwise 2.
|
|
195
|
-
|
|
196
233
|
Parameters
|
|
197
234
|
----------
|
|
198
235
|
mean : ndarray
|
|
@@ -206,26 +243,27 @@ class KalmanFilter(object):
|
|
|
206
243
|
only_position : Optional[bool]
|
|
207
244
|
If True, distance computation is done with respect to the bounding
|
|
208
245
|
box center position only.
|
|
209
|
-
|
|
210
246
|
Returns
|
|
211
247
|
-------
|
|
212
248
|
ndarray
|
|
213
249
|
Returns an array of length N, where the i-th element contains the
|
|
214
250
|
squared Mahalanobis distance between (mean, covariance) and
|
|
215
251
|
`measurements[i]`.
|
|
216
|
-
|
|
217
252
|
"""
|
|
218
|
-
import scipy.linalg # pylint: disable=import-error
|
|
219
|
-
|
|
220
253
|
mean, covariance = self.project(mean, covariance)
|
|
221
254
|
if only_position:
|
|
222
255
|
mean, covariance = mean[:2], covariance[:2, :2]
|
|
223
256
|
measurements = measurements[:, :2]
|
|
224
257
|
|
|
225
|
-
cholesky_factor = np.linalg.cholesky(covariance)
|
|
226
258
|
d = measurements - mean
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
259
|
+
if metric == 'gaussian':
|
|
260
|
+
return np.sum(d * d, axis=1)
|
|
261
|
+
elif metric == 'maha':
|
|
262
|
+
cholesky_factor = np.linalg.cholesky(covariance)
|
|
263
|
+
z = scipy.linalg.solve_triangular(
|
|
264
|
+
cholesky_factor, d.T, lower=True, check_finite=False,
|
|
265
|
+
overwrite_b=True)
|
|
266
|
+
squared_maha = np.sum(z * z, axis=0)
|
|
267
|
+
return squared_maha
|
|
268
|
+
else:
|
|
269
|
+
raise ValueError('invalid distance metric')
|