supervisely 6.73.418__py3-none-any.whl → 6.73.420__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50):
  1. supervisely/api/entity_annotation/figure_api.py +89 -45
  2. supervisely/nn/inference/inference.py +61 -45
  3. supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
  4. supervisely/nn/inference/object_detection/object_detection.py +1 -0
  5. supervisely/nn/inference/session.py +4 -4
  6. supervisely/nn/model/model_api.py +31 -20
  7. supervisely/nn/model/prediction.py +11 -0
  8. supervisely/nn/model/prediction_session.py +33 -6
  9. supervisely/nn/tracker/__init__.py +1 -2
  10. supervisely/nn/tracker/base_tracker.py +44 -0
  11. supervisely/nn/tracker/botsort/__init__.py +1 -0
  12. supervisely/nn/tracker/botsort/botsort_config.yaml +31 -0
  13. supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
  14. supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
  15. supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
  16. supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
  17. supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
  18. supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
  19. supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
  20. supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
  21. supervisely/nn/tracker/botsort_tracker.py +259 -0
  22. supervisely/project/project.py +212 -74
  23. {supervisely-6.73.418.dist-info → supervisely-6.73.420.dist-info}/METADATA +3 -1
  24. {supervisely-6.73.418.dist-info → supervisely-6.73.420.dist-info}/RECORD +29 -42
  25. supervisely/nn/tracker/bot_sort/__init__.py +0 -21
  26. supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
  27. supervisely/nn/tracker/bot_sort/matching.py +0 -127
  28. supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
  29. supervisely/nn/tracker/deep_sort/__init__.py +0 -6
  30. supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
  31. supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
  32. supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
  33. supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
  34. supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
  35. supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
  36. supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
  37. supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
  38. supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
  39. supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
  40. supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
  41. supervisely/nn/tracker/tracker.py +0 -285
  42. supervisely/nn/tracker/utils/kalman_filter.py +0 -492
  43. supervisely/nn/tracking/__init__.py +0 -1
  44. supervisely/nn/tracking/boxmot.py +0 -114
  45. supervisely/nn/tracking/tracking.py +0 -24
  46. supervisely/nn/tracker/{utils → botsort/osnet_reid}/__init__.py +0 -0
  47. {supervisely-6.73.418.dist-info → supervisely-6.73.420.dist-info}/LICENSE +0 -0
  48. {supervisely-6.73.418.dist-info → supervisely-6.73.420.dist-info}/WHEEL +0 -0
  49. {supervisely-6.73.418.dist-info → supervisely-6.73.420.dist-info}/entry_points.txt +0 -0
  50. {supervisely-6.73.418.dist-info → supervisely-6.73.420.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,88 @@
1
+ from pathlib import Path
2
+ import cv2
3
+ import numpy as np
4
+ from .osnet import osnet_x1_0
5
+ from collections import OrderedDict
6
+ from supervisely import logger
7
+
8
+ try:
9
+ # pylint: disable=import-error
10
+ import torch
11
+ from torch.nn import functional as F
12
+ except ImportError:
13
+ logger.warning("torch is not installed, OSNet re-ID cannot be used.")
14
+
15
+
16
+ class OsnetReIDModel:
17
+ def __init__(self, weights_path: Path = None, device: torch.device = torch.device("cpu"), half: bool = False):
18
+ self.device = device
19
+ self.half = half
20
+ self.input_shape = (256, 128)
21
+ if weights_path is None:
22
+ self.model = osnet_x1_0(num_classes=1000, loss='softmax', pretrained=True, use_gpu=device)
23
+ else:
24
+ self.model = osnet_x1_0(num_classes=1000, loss='softmax', pretrained=False, use_gpu=device)
25
+ self.load_pretrained_weights(weights_path)
26
+ self.model.to(self.device).eval()
27
+ if self.half:
28
+ self.model.half()
29
+
30
+ def load_pretrained_weights(self, weight_path: Path):
31
+ checkpoint = torch.load(weight_path, map_location=self.device)
32
+ state_dict = checkpoint.get("state_dict", checkpoint)
33
+ model_dict = self.model.state_dict()
34
+
35
+ new_state_dict = OrderedDict()
36
+ for k, v in state_dict.items():
37
+ key = k[7:] if k.startswith("module.") else k
38
+ if key in model_dict and model_dict[key].size() == v.size():
39
+ new_state_dict[key] = v
40
+
41
+ model_dict.update(new_state_dict)
42
+ self.model.load_state_dict(model_dict)
43
+
44
+ def get_features(self, xyxys, img: np.ndarray):
45
+ if xyxys.size == 0:
46
+ return np.empty((0, 512))
47
+
48
+ crops = self._get_crops(xyxys, img)
49
+ with torch.no_grad():
50
+ features = self.model(crops)
51
+ features = F.normalize(features, dim=1).cpu().numpy()
52
+
53
+ return features
54
+
55
+ def _get_crops(self, xyxys, img):
56
+ h, w = img.shape[:2]
57
+ crops = []
58
+
59
+ for box in xyxys:
60
+ x1, y1, x2, y2 = box.round().astype(int)
61
+ x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w, x2), min(h, y2)
62
+ crop = cv2.resize(img[y1:y2, x1:x2], self.input_shape[::-1])
63
+ crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
64
+ crop = crop.astype(np.float32) / 255.0
65
+ crop = (crop - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
66
+ crop = torch.from_numpy(crop).permute(2, 0, 1)
67
+ crops.append(crop)
68
+
69
+ batch = torch.stack(crops).to(self.device, dtype=torch.float16 if self.half else torch.float32)
70
+ return batch
71
+
72
+
73
+
74
+ class OsnetReIDInterface:
75
+ def __init__(self, weights: Path, device: str = "cpu", fp16: bool = False):
76
+ self.device = torch.device(device)
77
+ self.fp16 = fp16
78
+ self.model = OsnetReIDModel(weights, self.device, half=fp16)
79
+
80
+ def inference(self, image: np.ndarray, detections: np.ndarray) -> np.ndarray:
81
+ if detections is None or np.size(detections) == 0:
82
+ return np.zeros((0, 512), dtype=np.float32) # пустой набор фичей
83
+
84
+ xyxys = detections[:, 0:4] # left, top, right, bottom
85
+ features = self.model.get_features(xyxys, image)
86
+ return features
87
+
88
+
File without changes
@@ -1,6 +1,5 @@
1
- from collections import OrderedDict
2
-
3
1
  import numpy as np
2
+ from collections import OrderedDict
4
3
 
5
4
 
6
5
  class TrackState(object):
@@ -1,60 +1,59 @@
1
- import copy
2
- import time
3
-
4
1
  import cv2
5
2
  import matplotlib.pyplot as plt
6
3
  import numpy as np
4
+ import copy
5
+ import time
7
6
 
8
7
 
9
8
  class GMC:
10
- def __init__(self, method="sparseOptFlow", downscale=2, gmc_file: str = None):
9
+ def __init__(self, method='sparseOptFlow', downscale=2, verbose=None):
11
10
  super(GMC, self).__init__()
12
11
 
13
12
  self.method = method
14
13
  self.downscale = max(1, int(downscale))
15
14
 
16
- if self.method == "orb":
15
+ if self.method == 'orb':
17
16
  self.detector = cv2.FastFeatureDetector_create(20)
18
17
  self.extractor = cv2.ORB_create()
19
18
  self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
20
19
 
21
- elif self.method == "sift":
22
- self.detector = cv2.SIFT_create(
23
- nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20
24
- )
25
- self.extractor = cv2.SIFT_create(
26
- nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20
27
- )
20
+ elif self.method == 'sift':
21
+ self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
22
+ self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
28
23
  self.matcher = cv2.BFMatcher(cv2.NORM_L2)
29
24
 
30
- elif self.method == "ecc":
25
+ elif self.method == 'ecc':
31
26
  number_of_iterations = 5000
32
27
  termination_eps = 1e-6
33
28
  self.warp_mode = cv2.MOTION_EUCLIDEAN
34
- self.criteria = (
35
- cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
36
- number_of_iterations,
37
- termination_eps,
38
- )
39
-
40
- elif self.method == "sparseOptFlow":
41
- self.feature_params = dict(
42
- maxCorners=1000,
43
- qualityLevel=0.01,
44
- minDistance=1,
45
- blockSize=3,
46
- useHarrisDetector=False,
47
- k=0.04,
48
- )
29
+ self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
30
+
31
+ elif self.method == 'sparseOptFlow':
32
+ self.feature_params = dict(maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3,
33
+ useHarrisDetector=False, k=0.04)
49
34
  # self.gmc_file = open('GMC_results.txt', 'w')
50
35
 
51
- elif self.method == "file" or self.method == "files":
52
- self.gmcFile = open(gmc_file, "r")
53
- if self.gmcFile is None:
54
- raise ValueError("Error: Unable to open GMC file:" + gmc_file)
36
+ elif self.method == 'file' or self.method == 'files':
37
+ seqName = verbose[0]
38
+ ablation = verbose[1]
39
+ if ablation:
40
+ filePath = r'tracker/GMC_files/MOT17_ablation'
41
+ else:
42
+ filePath = r'tracker/GMC_files/MOTChallenge'
43
+
44
+ if '-FRCNN' in seqName:
45
+ seqName = seqName[:-6]
46
+ elif '-DPM' in seqName:
47
+ seqName = seqName[:-4]
48
+ elif '-SDP' in seqName:
49
+ seqName = seqName[:-4]
50
+
51
+ self.gmcFile = open(filePath + "/GMC-" + seqName + ".txt", 'r')
55
52
 
56
- elif self.method is None or self.method in ("None", "none"):
57
- self.method = "none"
53
+ if self.gmcFile is None:
54
+ raise ValueError("Error: Unable to open GMC file in directory:" + filePath)
55
+ elif self.method == 'none' or self.method == 'None':
56
+ self.method = 'none'
58
57
  else:
59
58
  raise ValueError("Error: Unknown CMC method:" + method)
60
59
 
@@ -65,15 +64,15 @@ class GMC:
65
64
  self.initializedFirstFrame = False
66
65
 
67
66
  def apply(self, raw_frame, detections=None):
68
- if self.method == "orb" or self.method == "sift":
67
+ if self.method == 'orb' or self.method == 'sift':
69
68
  return self.applyFeaures(raw_frame, detections)
70
- elif self.method == "ecc":
69
+ elif self.method == 'ecc':
71
70
  return self.applyEcc(raw_frame, detections)
72
- elif self.method == "sparseOptFlow":
71
+ elif self.method == 'sparseOptFlow':
73
72
  return self.applySparseOptFlow(raw_frame, detections)
74
- elif self.method == "file":
73
+ elif self.method == 'file':
75
74
  return self.applyFile(raw_frame, detections)
76
- elif self.method == "none":
75
+ elif self.method == 'none':
77
76
  return np.eye(2, 3)
78
77
  else:
79
78
  return np.eye(2, 3)
@@ -105,11 +104,9 @@ class GMC:
105
104
  # Run the ECC algorithm. The results are stored in warp_matrix.
106
105
  # (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
107
106
  try:
108
- (cc, H) = cv2.findTransformECC(
109
- self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1
110
- )
107
+ (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
111
108
  except:
112
- print("Warning: find transform failed. Set warp as identity")
109
+ print('Warning: find transform failed. Set warp as identity')
113
110
 
114
111
  return H
115
112
 
@@ -130,11 +127,11 @@ class GMC:
130
127
  # find the keypoints
131
128
  mask = np.zeros_like(frame)
132
129
  # mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
133
- mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255
130
+ mask[int(0.02 * height): int(0.98 * height), int(0.02 * width): int(0.98 * width)] = 255
134
131
  if detections is not None:
135
132
  for det in detections:
136
133
  tlbr = (det[:4] / self.downscale).astype(np.int_)
137
- mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0
134
+ mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
138
135
 
139
136
  keypoints = self.detector.detect(frame, mask)
140
137
 
@@ -176,14 +173,11 @@ class GMC:
176
173
  prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
177
174
  currKeyPointLocation = keypoints[m.trainIdx].pt
178
175
 
179
- spatialDistance = (
180
- prevKeyPointLocation[0] - currKeyPointLocation[0],
181
- prevKeyPointLocation[1] - currKeyPointLocation[1],
182
- )
176
+ spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0],
177
+ prevKeyPointLocation[1] - currKeyPointLocation[1])
183
178
 
184
- if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and (
185
- np.abs(spatialDistance[1]) < maxSpatialDistance[1]
186
- ):
179
+ if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
180
+ (np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
187
181
  spatialDistances.append(spatialDistance)
188
182
  matches.append(m)
189
183
 
@@ -233,7 +227,7 @@ class GMC:
233
227
  H[0, 2] *= self.downscale
234
228
  H[1, 2] *= self.downscale
235
229
  else:
236
- print("Warning: not enough matching points")
230
+ print('Warning: not enough matching points')
237
231
 
238
232
  # Store to next iteration
239
233
  self.prevFrame = frame.copy()
@@ -271,9 +265,7 @@ class GMC:
271
265
  return H
272
266
 
273
267
  # find correspondences
274
- matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(
275
- self.prevFrame, frame, self.prevKeyPoints, None
276
- )
268
+ matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
277
269
 
278
270
  # leave good correspondences only
279
271
  prevPoints = []
@@ -296,7 +288,7 @@ class GMC:
296
288
  H[0, 2] *= self.downscale
297
289
  H[1, 2] *= self.downscale
298
290
  else:
299
- print("Warning: not enough matching points")
291
+ print('Warning: not enough matching points')
300
292
 
301
293
  # Store to next iteration
302
294
  self.prevFrame = frame.copy()
@@ -313,7 +305,7 @@ class GMC:
313
305
  def applyFile(self, raw_frame, detections=None):
314
306
  line = self.gmcFile.readline()
315
307
  tokens = line.split("\t")
316
- H = np.eye(2, 3, dtype=np.float_)
308
+ H = np.eye(2, 3, dtype=np.float32)
317
309
  H[0, 0] = float(tokens[1])
318
310
  H[0, 1] = float(tokens[2])
319
311
  H[0, 2] = float(tokens[3])
@@ -321,4 +313,4 @@ class GMC:
321
313
  H[1, 1] = float(tokens[5])
322
314
  H[1, 2] = float(tokens[6])
323
315
 
324
- return H
316
+ return H
@@ -1,5 +1,7 @@
1
1
  # vim: expandtab:ts=4:sw=4
2
2
  import numpy as np
3
+ import scipy.linalg
4
+
3
5
 
4
6
  """
5
7
  Table for the 0.95 quantile of the chi-square distribution with N degrees of
@@ -24,13 +26,13 @@ class KalmanFilter(object):
24
26
 
25
27
  The 8-dimensional state space
26
28
 
27
- x, y, a, h, vx, vy, va, vh
29
+ x, y, w, h, vx, vy, vw, vh
28
30
 
29
- contains the bounding box center position (x, y), aspect ratio a, height h,
31
+ contains the bounding box center position (x, y), width w, height h,
30
32
  and their respective velocities.
31
33
 
32
34
  Object motion follows a constant velocity model. The bounding box location
33
- (x, y, a, h) is taken as direct observation of the state space (linear
35
+ (x, y, w, h) is taken as direct observation of the state space (linear
34
36
  observation model).
35
37
 
36
38
  """
@@ -56,8 +58,8 @@ class KalmanFilter(object):
56
58
  Parameters
57
59
  ----------
58
60
  measurement : ndarray
59
- Bounding box coordinates (x, y, a, h) with center position (x, y),
60
- aspect ratio a, and height h.
61
+ Bounding box coordinates (x, y, w, h) with center position (x, y),
62
+ width w, and height h.
61
63
 
62
64
  Returns
63
65
  -------
@@ -72,13 +74,13 @@ class KalmanFilter(object):
72
74
  mean = np.r_[mean_pos, mean_vel]
73
75
 
74
76
  std = [
77
+ 2 * self._std_weight_position * measurement[2],
75
78
  2 * self._std_weight_position * measurement[3],
79
+ 2 * self._std_weight_position * measurement[2],
76
80
  2 * self._std_weight_position * measurement[3],
77
- 1e-2,
78
- 2 * self._std_weight_position * measurement[3],
81
+ 10 * self._std_weight_velocity * measurement[2],
79
82
  10 * self._std_weight_velocity * measurement[3],
80
- 10 * self._std_weight_velocity * measurement[3],
81
- 1e-5,
83
+ 10 * self._std_weight_velocity * measurement[2],
82
84
  10 * self._std_weight_velocity * measurement[3]]
83
85
  covariance = np.diag(np.square(std))
84
86
  return mean, covariance
@@ -103,18 +105,18 @@ class KalmanFilter(object):
103
105
 
104
106
  """
105
107
  std_pos = [
108
+ self._std_weight_position * mean[2],
106
109
  self._std_weight_position * mean[3],
107
- self._std_weight_position * mean[3],
108
- 1e-2,
110
+ self._std_weight_position * mean[2],
109
111
  self._std_weight_position * mean[3]]
110
112
  std_vel = [
113
+ self._std_weight_velocity * mean[2],
111
114
  self._std_weight_velocity * mean[3],
112
- self._std_weight_velocity * mean[3],
113
- 1e-5,
115
+ self._std_weight_velocity * mean[2],
114
116
  self._std_weight_velocity * mean[3]]
115
117
  motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
116
118
 
117
- mean = np.dot(self._motion_mat, mean)
119
+ mean = np.dot(mean, self._motion_mat.T)
118
120
  covariance = np.linalg.multi_dot((
119
121
  self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
120
122
 
@@ -138,9 +140,9 @@ class KalmanFilter(object):
138
140
 
139
141
  """
140
142
  std = [
143
+ self._std_weight_position * mean[2],
141
144
  self._std_weight_position * mean[3],
142
- self._std_weight_position * mean[3],
143
- 1e-1,
145
+ self._std_weight_position * mean[2],
144
146
  self._std_weight_position * mean[3]]
145
147
  innovation_cov = np.diag(np.square(std))
146
148
 
@@ -149,6 +151,45 @@ class KalmanFilter(object):
149
151
  self._update_mat, covariance, self._update_mat.T))
150
152
  return mean, covariance + innovation_cov
151
153
 
154
+ def multi_predict(self, mean, covariance):
155
+ """Run Kalman filter prediction step (Vectorized version).
156
+ Parameters
157
+ ----------
158
+ mean : ndarray
159
+ The Nx8 dimensional mean matrix of the object states at the previous
160
+ time step.
161
+ covariance : ndarray
162
+ The Nx8x8 dimensional covariance matrics of the object states at the
163
+ previous time step.
164
+ Returns
165
+ -------
166
+ (ndarray, ndarray)
167
+ Returns the mean vector and covariance matrix of the predicted
168
+ state. Unobserved velocities are initialized to 0 mean.
169
+ """
170
+ std_pos = [
171
+ self._std_weight_position * mean[:, 2],
172
+ self._std_weight_position * mean[:, 3],
173
+ self._std_weight_position * mean[:, 2],
174
+ self._std_weight_position * mean[:, 3]]
175
+ std_vel = [
176
+ self._std_weight_velocity * mean[:, 2],
177
+ self._std_weight_velocity * mean[:, 3],
178
+ self._std_weight_velocity * mean[:, 2],
179
+ self._std_weight_velocity * mean[:, 3]]
180
+ sqr = np.square(np.r_[std_pos, std_vel]).T
181
+
182
+ motion_cov = []
183
+ for i in range(len(mean)):
184
+ motion_cov.append(np.diag(sqr[i]))
185
+ motion_cov = np.asarray(motion_cov)
186
+
187
+ mean = np.dot(mean, self._motion_mat.T)
188
+ left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
189
+ covariance = np.dot(left, self._motion_mat.T) + motion_cov
190
+
191
+ return mean, covariance
192
+
152
193
  def update(self, mean, covariance, measurement):
153
194
  """Run Kalman filter correction step.
154
195
 
@@ -159,8 +200,8 @@ class KalmanFilter(object):
159
200
  covariance : ndarray
160
201
  The state's covariance matrix (8x8 dimensional).
161
202
  measurement : ndarray
162
- The 4 dimensional measurement vector (x, y, a, h), where (x, y)
163
- is the center position, a the aspect ratio, and h the height of the
203
+ The 4 dimensional measurement vector (x, y, w, h), where (x, y)
204
+ is the center position, w the width, and h the height of the
164
205
  bounding box.
165
206
 
166
207
  Returns
@@ -169,8 +210,6 @@ class KalmanFilter(object):
169
210
  Returns the measurement-corrected state distribution.
170
211
 
171
212
  """
172
- import scipy.linalg # pylint: disable=import-error
173
-
174
213
  projected_mean, projected_cov = self.project(mean, covariance)
175
214
 
176
215
  chol_factor, lower = scipy.linalg.cho_factor(
@@ -186,13 +225,11 @@ class KalmanFilter(object):
186
225
  return new_mean, new_covariance
187
226
 
188
227
  def gating_distance(self, mean, covariance, measurements,
189
- only_position=False):
228
+ only_position=False, metric='maha'):
190
229
  """Compute gating distance between state distribution and measurements.
191
-
192
230
  A suitable distance threshold can be obtained from `chi2inv95`. If
193
231
  `only_position` is False, the chi-square distribution has 4 degrees of
194
232
  freedom, otherwise 2.
195
-
196
233
  Parameters
197
234
  ----------
198
235
  mean : ndarray
@@ -206,26 +243,27 @@ class KalmanFilter(object):
206
243
  only_position : Optional[bool]
207
244
  If True, distance computation is done with respect to the bounding
208
245
  box center position only.
209
-
210
246
  Returns
211
247
  -------
212
248
  ndarray
213
249
  Returns an array of length N, where the i-th element contains the
214
250
  squared Mahalanobis distance between (mean, covariance) and
215
251
  `measurements[i]`.
216
-
217
252
  """
218
- import scipy.linalg # pylint: disable=import-error
219
-
220
253
  mean, covariance = self.project(mean, covariance)
221
254
  if only_position:
222
255
  mean, covariance = mean[:2], covariance[:2, :2]
223
256
  measurements = measurements[:, :2]
224
257
 
225
- cholesky_factor = np.linalg.cholesky(covariance)
226
258
  d = measurements - mean
227
- z = scipy.linalg.solve_triangular(
228
- cholesky_factor, d.T, lower=True, check_finite=False,
229
- overwrite_b=True)
230
- squared_maha = np.sum(z * z, axis=0)
231
- return squared_maha
259
+ if metric == 'gaussian':
260
+ return np.sum(d * d, axis=1)
261
+ elif metric == 'maha':
262
+ cholesky_factor = np.linalg.cholesky(covariance)
263
+ z = scipy.linalg.solve_triangular(
264
+ cholesky_factor, d.T, lower=True, check_finite=False,
265
+ overwrite_b=True)
266
+ squared_maha = np.sum(z * z, axis=0)
267
+ return squared_maha
268
+ else:
269
+ raise ValueError('invalid distance metric')