supervisely 6.73.418__py3-none-any.whl → 6.73.419__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supervisely/api/entity_annotation/figure_api.py +89 -45
- supervisely/nn/inference/inference.py +61 -45
- supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
- supervisely/nn/inference/object_detection/object_detection.py +1 -0
- supervisely/nn/inference/session.py +4 -4
- supervisely/nn/model/model_api.py +31 -20
- supervisely/nn/model/prediction.py +11 -0
- supervisely/nn/model/prediction_session.py +33 -6
- supervisely/nn/tracker/__init__.py +1 -2
- supervisely/nn/tracker/base_tracker.py +44 -0
- supervisely/nn/tracker/botsort/__init__.py +1 -0
- supervisely/nn/tracker/botsort/botsort_config.yaml +31 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
- supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
- supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
- supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
- supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
- supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
- supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
- supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
- supervisely/nn/tracker/botsort_tracker.py +259 -0
- supervisely/project/project.py +1 -1
- {supervisely-6.73.418.dist-info → supervisely-6.73.419.dist-info}/METADATA +3 -1
- {supervisely-6.73.418.dist-info → supervisely-6.73.419.dist-info}/RECORD +29 -42
- supervisely/nn/tracker/bot_sort/__init__.py +0 -21
- supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
- supervisely/nn/tracker/bot_sort/matching.py +0 -127
- supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
- supervisely/nn/tracker/deep_sort/__init__.py +0 -6
- supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
- supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
- supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
- supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
- supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
- supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
- supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
- supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
- supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
- supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
- supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
- supervisely/nn/tracker/tracker.py +0 -285
- supervisely/nn/tracker/utils/kalman_filter.py +0 -492
- supervisely/nn/tracking/__init__.py +0 -1
- supervisely/nn/tracking/boxmot.py +0 -114
- supervisely/nn/tracking/tracking.py +0 -24
- supervisely/nn/tracker/{utils → botsort/osnet_reid}/__init__.py +0 -0
- {supervisely-6.73.418.dist-info → supervisely-6.73.419.dist-info}/LICENSE +0 -0
- {supervisely-6.73.418.dist-info → supervisely-6.73.419.dist-info}/WHEEL +0 -0
- {supervisely-6.73.418.dist-info → supervisely-6.73.419.dist-info}/entry_points.txt +0 -0
- {supervisely-6.73.418.dist-info → supervisely-6.73.419.dist-info}/top_level.txt +0 -0
| @@ -0,0 +1,88 @@ | |
| 1 | 
            +
            from pathlib import Path
         | 
| 2 | 
            +
            import cv2
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            from .osnet import osnet_x1_0
         | 
| 5 | 
            +
            from collections import OrderedDict
         | 
| 6 | 
            +
            from supervisely import logger
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            try:
         | 
| 9 | 
            +
                # pylint: disable=import-error
         | 
| 10 | 
            +
                import torch
         | 
| 11 | 
            +
                from torch.nn import functional as F
         | 
| 12 | 
            +
            except ImportError:
         | 
| 13 | 
            +
                logger.warning("torch is not installed, OSNet re-ID cannot be used.")
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            class OsnetReIDModel:
                """OSNet-x1.0 appearance-embedding extractor used for re-identification.

                Crops every box out of a frame, preprocesses the crops and runs them
                through the ``osnet_x1_0`` backbone, returning one L2-normalised
                feature row per box.
                """

                # Per-channel normalisation for RGB crops scaled to [0, 1]
                # (the standard ImageNet mean / std values). float32 so the
                # per-crop arithmetic does not allocate float64 temporaries.
                _PIXEL_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
                _PIXEL_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

                def __init__(self, weights_path: Path = None, device: "torch.device | str | None" = None, half: bool = False):
                    """Build the backbone and move it to ``device``.

                    :param weights_path: Checkpoint loaded via ``load_pretrained_weights``.
                        When ``None`` the backbone is created with ``pretrained=True`` instead.
                    :param device: Target device (``torch.device`` or device string);
                        ``None`` means CPU.
                    :param half: Run the model (and the input batches) in float16.
                    """
                    # Resolved here rather than in the signature: a ``torch.device(...)``
                    # default is evaluated when the class is defined, which raises
                    # NameError if the guarded torch import of this module failed.
                    self.device = torch.device("cpu") if device is None else torch.device(device)
                    self.half = half
                    # (height, width) of the network input; cv2.resize expects (width, height).
                    self.input_shape = (256, 128)

                    use_pretrained = weights_path is None
                    # NOTE(review): use_gpu receives a torch.device, not a bool — confirm
                    # this is what osnet_x1_0 expects.
                    self.model = osnet_x1_0(num_classes=1000, loss='softmax', pretrained=use_pretrained, use_gpu=self.device)
                    if not use_pretrained:
                        self.load_pretrained_weights(weights_path)
                    self.model.to(self.device).eval()
                    if self.half:
                        self.model.half()

                def load_pretrained_weights(self, weight_path: Path):
                    """Load a checkpoint into ``self.model``, keeping only compatible tensors.

                    The checkpoint's ``"state_dict"`` entry is used when present, otherwise
                    the checkpoint itself is treated as the state dict. A leading
                    ``"module."`` is stripped from each key (presumably from
                    DataParallel-style checkpoints). Tensors whose name is unknown to the
                    model or whose shape differs are skipped, and a warning is logged.

                    :param weight_path: Path of the checkpoint file.
                    """
                    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
                    checkpoint = torch.load(weight_path, map_location=self.device)
                    state_dict = checkpoint.get("state_dict", checkpoint)
                    model_dict = self.model.state_dict()

                    prefix = "module."
                    matched = OrderedDict()
                    skipped = []
                    for name, tensor in state_dict.items():
                        key = name[len(prefix):] if name.startswith(prefix) else name
                        if key in model_dict and model_dict[key].size() == tensor.size():
                            matched[key] = tensor
                        else:
                            skipped.append(key)

                    # Previously mismatches were dropped silently, leaving those layers at
                    # their initial values without any hint to the user.
                    if skipped or not matched:
                        logger.warning(
                            "OSNet re-ID: loaded %d tensors from %s, ignored %d (unknown name or shape mismatch).",
                            len(matched),
                            weight_path,
                            len(skipped),
                        )

                    model_dict.update(matched)
                    self.model.load_state_dict(model_dict)

                def get_features(self, xyxys, img: np.ndarray):
                    """Return one L2-normalised embedding per box.

                    :param xyxys: Boxes as ``(x1, y1, x2, y2)`` pixel coordinates — an
                        ``(N, 4)`` array; a list or a single box is accepted too.
                    :param img: Frame the boxes refer to. Crops are converted BGR -> RGB,
                        so presumably a cv2-style BGR image (TODO confirm with callers).
                    :return: ``(N, D)`` numpy array of normalised features, or an empty
                        ``(0, 512)`` float32 array when there are no boxes.
                    """
                    boxes = np.atleast_2d(np.asarray(xyxys))
                    if boxes.size == 0:
                        # float32, the same dtype OsnetReIDInterface uses for its empty result.
                        return np.empty((0, 512), dtype=np.float32)

                    crops = self._get_crops(boxes, img)
                    with torch.no_grad():
                        features = self.model(crops)
                        features = F.normalize(features, dim=1).cpu().numpy()

                    return features

                def _get_crops(self, xyxys, img):
                    """Cut, resize and normalise every box into one ``(N, 3, H, W)`` tensor batch."""
                    img_h, img_w = img.shape[:2]
                    out_h, out_w = self.input_shape
                    tensors = []

                    for box in xyxys:
                        x1, y1, x2, y2 = box.round().astype(int)
                        # Clip to the frame but keep at least one pixel in each direction:
                        # a box lying outside the image (or collapsing to zero size after
                        # clipping) would otherwise give an empty slice, which cv2.resize
                        # rejects. Boxes already inside the frame are unchanged.
                        x1 = min(max(0, x1), img_w - 1)
                        y1 = min(max(0, y1), img_h - 1)
                        x2 = max(min(img_w, x2), x1 + 1)
                        y2 = max(min(img_h, y2), y1 + 1)

                        crop = cv2.resize(img[y1:y2, x1:x2], (out_w, out_h))
                        crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
                        crop = crop.astype(np.float32) / 255.0
                        crop = (crop - self._PIXEL_MEAN) / self._PIXEL_STD
                        # HWC -> CHW for the network.
                        tensors.append(torch.from_numpy(crop).permute(2, 0, 1))

                    dtype = torch.float16 if self.half else torch.float32
                    return torch.stack(tensors).to(self.device, dtype=dtype)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            class OsnetReIDInterface:
                """Adapter that turns detection arrays into OSNet re-ID embeddings."""

                def __init__(self, weights: Path, device: str = "cpu", fp16: bool = False):
                    """Create the underlying OsnetReIDModel.

                    :param weights: Checkpoint path forwarded to OsnetReIDModel; ``None``
                        makes it build the backbone with ``pretrained=True`` instead.
                    :param device: Torch device string, e.g. ``"cpu"``.
                    :param fp16: Run inference in float16.
                    """
                    self.device = torch.device(device)
                    self.fp16 = fp16
                    self.model = OsnetReIDModel(weights, self.device, half=fp16)

                def inference(self, image: np.ndarray, detections: np.ndarray) -> np.ndarray:
                    """Compute one embedding per detection.

                    :param image: Frame the detections were produced on.
                    :param detections: Array whose first four columns are
                        ``left, top, right, bottom``; extra columns are ignored. A list or a
                        single 1-D detection is accepted as well.
                    :return: ``(N, D)`` feature array, or an empty ``(0, 512)`` float32 array
                        when there are no detections.
                    """
                    dets = None if detections is None else np.atleast_2d(np.asarray(detections))
                    if dets is None or dets.size == 0:
                        return np.zeros((0, 512), dtype=np.float32)  # empty feature set

                    xyxys = dets[:, 0:4]  # left, top, right, bottom
                    return self.model.get_features(xyxys, image)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             
         | 
| 
            File without changes
         | 
| @@ -1,60 +1,59 @@ | |
| 1 | 
            -
            import copy
         | 
| 2 | 
            -
            import time
         | 
| 3 | 
            -
             | 
| 4 1 | 
             
            import cv2
         | 
| 5 2 | 
             
            import matplotlib.pyplot as plt
         | 
| 6 3 | 
             
            import numpy as np
         | 
| 4 | 
            +
            import copy
         | 
| 5 | 
            +
            import time
         | 
| 7 6 |  | 
| 8 7 |  | 
| 9 8 | 
             
            class GMC:
         | 
| 10 | 
            -
                def __init__(self, method= | 
| 9 | 
            +
                def __init__(self, method='sparseOptFlow', downscale=2, verbose=None):
         | 
| 11 10 | 
             
                    super(GMC, self).__init__()
         | 
| 12 11 |  | 
| 13 12 | 
             
                    self.method = method
         | 
| 14 13 | 
             
                    self.downscale = max(1, int(downscale))
         | 
| 15 14 |  | 
| 16 | 
            -
                    if self.method ==  | 
| 15 | 
            +
                    if self.method == 'orb':
         | 
| 17 16 | 
             
                        self.detector = cv2.FastFeatureDetector_create(20)
         | 
| 18 17 | 
             
                        self.extractor = cv2.ORB_create()
         | 
| 19 18 | 
             
                        self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
         | 
| 20 19 |  | 
| 21 | 
            -
                    elif self.method ==  | 
| 22 | 
            -
                        self.detector = cv2.SIFT_create(
         | 
| 23 | 
            -
             | 
| 24 | 
            -
                        )
         | 
| 25 | 
            -
                        self.extractor = cv2.SIFT_create(
         | 
| 26 | 
            -
                            nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20
         | 
| 27 | 
            -
                        )
         | 
| 20 | 
            +
                    elif self.method == 'sift':
         | 
| 21 | 
            +
                        self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
         | 
| 22 | 
            +
                        self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
         | 
| 28 23 | 
             
                        self.matcher = cv2.BFMatcher(cv2.NORM_L2)
         | 
| 29 24 |  | 
| 30 | 
            -
                    elif self.method ==  | 
| 25 | 
            +
                    elif self.method == 'ecc':
         | 
| 31 26 | 
             
                        number_of_iterations = 5000
         | 
| 32 27 | 
             
                        termination_eps = 1e-6
         | 
| 33 28 | 
             
                        self.warp_mode = cv2.MOTION_EUCLIDEAN
         | 
| 34 | 
            -
                        self.criteria = (
         | 
| 35 | 
            -
             | 
| 36 | 
            -
             | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
                    elif self.method == "sparseOptFlow":
         | 
| 41 | 
            -
                        self.feature_params = dict(
         | 
| 42 | 
            -
                            maxCorners=1000,
         | 
| 43 | 
            -
                            qualityLevel=0.01,
         | 
| 44 | 
            -
                            minDistance=1,
         | 
| 45 | 
            -
                            blockSize=3,
         | 
| 46 | 
            -
                            useHarrisDetector=False,
         | 
| 47 | 
            -
                            k=0.04,
         | 
| 48 | 
            -
                        )
         | 
| 29 | 
            +
                        self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    elif self.method == 'sparseOptFlow':
         | 
| 32 | 
            +
                        self.feature_params = dict(maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3,
         | 
| 33 | 
            +
                                                   useHarrisDetector=False, k=0.04)
         | 
| 49 34 | 
             
                        # self.gmc_file = open('GMC_results.txt', 'w')
         | 
| 50 35 |  | 
| 51 | 
            -
                    elif self.method ==  | 
| 52 | 
            -
                         | 
| 53 | 
            -
                         | 
| 54 | 
            -
             | 
| 36 | 
            +
                    elif self.method == 'file' or self.method == 'files':
         | 
| 37 | 
            +
                        seqName = verbose[0]
         | 
| 38 | 
            +
                        ablation = verbose[1]
         | 
| 39 | 
            +
                        if ablation:
         | 
| 40 | 
            +
                            filePath = r'tracker/GMC_files/MOT17_ablation'
         | 
| 41 | 
            +
                        else:
         | 
| 42 | 
            +
                            filePath = r'tracker/GMC_files/MOTChallenge'
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                        if '-FRCNN' in seqName:
         | 
| 45 | 
            +
                            seqName = seqName[:-6]
         | 
| 46 | 
            +
                        elif '-DPM' in seqName:
         | 
| 47 | 
            +
                            seqName = seqName[:-4]
         | 
| 48 | 
            +
                        elif '-SDP' in seqName:
         | 
| 49 | 
            +
                            seqName = seqName[:-4]
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                        self.gmcFile = open(filePath + "/GMC-" + seqName + ".txt", 'r')
         | 
| 55 52 |  | 
| 56 | 
            -
             | 
| 57 | 
            -
             | 
| 53 | 
            +
                        if self.gmcFile is None:
         | 
| 54 | 
            +
                            raise ValueError("Error: Unable to open GMC file in directory:" + filePath)
         | 
| 55 | 
            +
                    elif self.method == 'none' or self.method == 'None':
         | 
| 56 | 
            +
                        self.method = 'none'
         | 
| 58 57 | 
             
                    else:
         | 
| 59 58 | 
             
                        raise ValueError("Error: Unknown CMC method:" + method)
         | 
| 60 59 |  | 
| @@ -65,15 +64,15 @@ class GMC: | |
| 65 64 | 
             
                    self.initializedFirstFrame = False
         | 
| 66 65 |  | 
| 67 66 | 
             
                def apply(self, raw_frame, detections=None):
         | 
| 68 | 
            -
                    if self.method ==  | 
| 67 | 
            +
                    if self.method == 'orb' or self.method == 'sift':
         | 
| 69 68 | 
             
                        return self.applyFeaures(raw_frame, detections)
         | 
| 70 | 
            -
                    elif self.method ==  | 
| 69 | 
            +
                    elif self.method == 'ecc':
         | 
| 71 70 | 
             
                        return self.applyEcc(raw_frame, detections)
         | 
| 72 | 
            -
                    elif self.method ==  | 
| 71 | 
            +
                    elif self.method == 'sparseOptFlow':
         | 
| 73 72 | 
             
                        return self.applySparseOptFlow(raw_frame, detections)
         | 
| 74 | 
            -
                    elif self.method ==  | 
| 73 | 
            +
                    elif self.method == 'file':
         | 
| 75 74 | 
             
                        return self.applyFile(raw_frame, detections)
         | 
| 76 | 
            -
                    elif self.method ==  | 
| 75 | 
            +
                    elif self.method == 'none':
         | 
| 77 76 | 
             
                        return np.eye(2, 3)
         | 
| 78 77 | 
             
                    else:
         | 
| 79 78 | 
             
                        return np.eye(2, 3)
         | 
| @@ -105,11 +104,9 @@ class GMC: | |
| 105 104 | 
             
                    # Run the ECC algorithm. The results are stored in warp_matrix.
         | 
| 106 105 | 
             
                    # (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
         | 
| 107 106 | 
             
                    try:
         | 
| 108 | 
            -
                        (cc, H) = cv2.findTransformECC(
         | 
| 109 | 
            -
                            self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1
         | 
| 110 | 
            -
                        )
         | 
| 107 | 
            +
                        (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
         | 
| 111 108 | 
             
                    except:
         | 
| 112 | 
            -
                        print( | 
| 109 | 
            +
                        print('Warning: find transform failed. Set warp as identity')
         | 
| 113 110 |  | 
| 114 111 | 
             
                    return H
         | 
| 115 112 |  | 
| @@ -130,11 +127,11 @@ class GMC: | |
| 130 127 | 
             
                    # find the keypoints
         | 
| 131 128 | 
             
                    mask = np.zeros_like(frame)
         | 
| 132 129 | 
             
                    # mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
         | 
| 133 | 
            -
                    mask[int(0.02 * height) | 
| 130 | 
            +
                    mask[int(0.02 * height): int(0.98 * height), int(0.02 * width): int(0.98 * width)] = 255
         | 
| 134 131 | 
             
                    if detections is not None:
         | 
| 135 132 | 
             
                        for det in detections:
         | 
| 136 133 | 
             
                            tlbr = (det[:4] / self.downscale).astype(np.int_)
         | 
| 137 | 
            -
                            mask[tlbr[1] | 
| 134 | 
            +
                            mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
         | 
| 138 135 |  | 
| 139 136 | 
             
                    keypoints = self.detector.detect(frame, mask)
         | 
| 140 137 |  | 
| @@ -176,14 +173,11 @@ class GMC: | |
| 176 173 | 
             
                            prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
         | 
| 177 174 | 
             
                            currKeyPointLocation = keypoints[m.trainIdx].pt
         | 
| 178 175 |  | 
| 179 | 
            -
                            spatialDistance = (
         | 
| 180 | 
            -
             | 
| 181 | 
            -
                                prevKeyPointLocation[1] - currKeyPointLocation[1],
         | 
| 182 | 
            -
                            )
         | 
| 176 | 
            +
                            spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0],
         | 
| 177 | 
            +
                                               prevKeyPointLocation[1] - currKeyPointLocation[1])
         | 
| 183 178 |  | 
| 184 | 
            -
                            if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and  | 
| 185 | 
            -
             | 
| 186 | 
            -
                            ):
         | 
| 179 | 
            +
                            if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
         | 
| 180 | 
            +
                                    (np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
         | 
| 187 181 | 
             
                                spatialDistances.append(spatialDistance)
         | 
| 188 182 | 
             
                                matches.append(m)
         | 
| 189 183 |  | 
| @@ -233,7 +227,7 @@ class GMC: | |
| 233 227 | 
             
                            H[0, 2] *= self.downscale
         | 
| 234 228 | 
             
                            H[1, 2] *= self.downscale
         | 
| 235 229 | 
             
                    else:
         | 
| 236 | 
            -
                        print( | 
| 230 | 
            +
                        print('Warning: not enough matching points')
         | 
| 237 231 |  | 
| 238 232 | 
             
                    # Store to next iteration
         | 
| 239 233 | 
             
                    self.prevFrame = frame.copy()
         | 
| @@ -271,9 +265,7 @@ class GMC: | |
| 271 265 | 
             
                        return H
         | 
| 272 266 |  | 
| 273 267 | 
             
                    # find correspondences
         | 
| 274 | 
            -
                    matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(
         | 
| 275 | 
            -
                        self.prevFrame, frame, self.prevKeyPoints, None
         | 
| 276 | 
            -
                    )
         | 
| 268 | 
            +
                    matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
         | 
| 277 269 |  | 
| 278 270 | 
             
                    # leave good correspondences only
         | 
| 279 271 | 
             
                    prevPoints = []
         | 
| @@ -296,7 +288,7 @@ class GMC: | |
| 296 288 | 
             
                            H[0, 2] *= self.downscale
         | 
| 297 289 | 
             
                            H[1, 2] *= self.downscale
         | 
| 298 290 | 
             
                    else:
         | 
| 299 | 
            -
                        print( | 
| 291 | 
            +
                        print('Warning: not enough matching points')
         | 
| 300 292 |  | 
| 301 293 | 
             
                    # Store to next iteration
         | 
| 302 294 | 
             
                    self.prevFrame = frame.copy()
         | 
| @@ -313,7 +305,7 @@ class GMC: | |
| 313 305 | 
             
                def applyFile(self, raw_frame, detections=None):
         | 
| 314 306 | 
             
                    line = self.gmcFile.readline()
         | 
| 315 307 | 
             
                    tokens = line.split("\t")
         | 
| 316 | 
            -
                    H = np.eye(2, 3, dtype=np. | 
| 308 | 
            +
                    H = np.eye(2, 3, dtype=np.float32)
         | 
| 317 309 | 
             
                    H[0, 0] = float(tokens[1])
         | 
| 318 310 | 
             
                    H[0, 1] = float(tokens[2])
         | 
| 319 311 | 
             
                    H[0, 2] = float(tokens[3])
         | 
| @@ -321,4 +313,4 @@ class GMC: | |
| 321 313 | 
             
                    H[1, 1] = float(tokens[5])
         | 
| 322 314 | 
             
                    H[1, 2] = float(tokens[6])
         | 
| 323 315 |  | 
| 324 | 
            -
                    return H
         | 
| 316 | 
            +
                    return H
         | 
| @@ -1,5 +1,7 @@ | |
| 1 1 | 
             
            # vim: expandtab:ts=4:sw=4
         | 
| 2 2 | 
             
            import numpy as np
         | 
| 3 | 
            +
            import scipy.linalg
         | 
| 4 | 
            +
             | 
| 3 5 |  | 
| 4 6 | 
             
            """
         | 
| 5 7 | 
             
            Table for the 0.95 quantile of the chi-square distribution with N degrees of
         | 
| @@ -24,13 +26,13 @@ class KalmanFilter(object): | |
| 24 26 |  | 
| 25 27 | 
             
                The 8-dimensional state space
         | 
| 26 28 |  | 
| 27 | 
            -
                    x, y,  | 
| 29 | 
            +
                    x, y, w, h, vx, vy, vw, vh
         | 
| 28 30 |  | 
| 29 | 
            -
                contains the bounding box center position (x, y),  | 
| 31 | 
            +
                contains the bounding box center position (x, y), width w, height h,
         | 
| 30 32 | 
             
                and their respective velocities.
         | 
| 31 33 |  | 
| 32 34 | 
             
                Object motion follows a constant velocity model. The bounding box location
         | 
| 33 | 
            -
                (x, y,  | 
| 35 | 
            +
                (x, y, w, h) is taken as direct observation of the state space (linear
         | 
| 34 36 | 
             
                observation model).
         | 
| 35 37 |  | 
| 36 38 | 
             
                """
         | 
| @@ -56,8 +58,8 @@ class KalmanFilter(object): | |
| 56 58 | 
             
                    Parameters
         | 
| 57 59 | 
             
                    ----------
         | 
| 58 60 | 
             
                    measurement : ndarray
         | 
| 59 | 
            -
                        Bounding box coordinates (x, y,  | 
| 60 | 
            -
                         | 
| 61 | 
            +
                        Bounding box coordinates (x, y, w, h) with center position (x, y),
         | 
| 62 | 
            +
                        width w, and height h.
         | 
| 61 63 |  | 
| 62 64 | 
             
                    Returns
         | 
| 63 65 | 
             
                    -------
         | 
| @@ -72,13 +74,13 @@ class KalmanFilter(object): | |
| 72 74 | 
             
                    mean = np.r_[mean_pos, mean_vel]
         | 
| 73 75 |  | 
| 74 76 | 
             
                    std = [
         | 
| 77 | 
            +
                        2 * self._std_weight_position * measurement[2],
         | 
| 75 78 | 
             
                        2 * self._std_weight_position * measurement[3],
         | 
| 79 | 
            +
                        2 * self._std_weight_position * measurement[2],
         | 
| 76 80 | 
             
                        2 * self._std_weight_position * measurement[3],
         | 
| 77 | 
            -
                         | 
| 78 | 
            -
                        2 * self._std_weight_position * measurement[3],
         | 
| 81 | 
            +
                        10 * self._std_weight_velocity * measurement[2],
         | 
| 79 82 | 
             
                        10 * self._std_weight_velocity * measurement[3],
         | 
| 80 | 
            -
                        10 * self._std_weight_velocity * measurement[ | 
| 81 | 
            -
                        1e-5,
         | 
| 83 | 
            +
                        10 * self._std_weight_velocity * measurement[2],
         | 
| 82 84 | 
             
                        10 * self._std_weight_velocity * measurement[3]]
         | 
| 83 85 | 
             
                    covariance = np.diag(np.square(std))
         | 
| 84 86 | 
             
                    return mean, covariance
         | 
| @@ -103,18 +105,18 @@ class KalmanFilter(object): | |
| 103 105 |  | 
| 104 106 | 
             
                    """
         | 
| 105 107 | 
             
                    std_pos = [
         | 
| 108 | 
            +
                        self._std_weight_position * mean[2],
         | 
| 106 109 | 
             
                        self._std_weight_position * mean[3],
         | 
| 107 | 
            -
                        self._std_weight_position * mean[ | 
| 108 | 
            -
                        1e-2,
         | 
| 110 | 
            +
                        self._std_weight_position * mean[2],
         | 
| 109 111 | 
             
                        self._std_weight_position * mean[3]]
         | 
| 110 112 | 
             
                    std_vel = [
         | 
| 113 | 
            +
                        self._std_weight_velocity * mean[2],
         | 
| 111 114 | 
             
                        self._std_weight_velocity * mean[3],
         | 
| 112 | 
            -
                        self._std_weight_velocity * mean[ | 
| 113 | 
            -
                        1e-5,
         | 
| 115 | 
            +
                        self._std_weight_velocity * mean[2],
         | 
| 114 116 | 
             
                        self._std_weight_velocity * mean[3]]
         | 
| 115 117 | 
             
                    motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
         | 
| 116 118 |  | 
| 117 | 
            -
                    mean = np.dot(self._motion_mat | 
| 119 | 
            +
                    mean = np.dot(mean, self._motion_mat.T)
         | 
| 118 120 | 
             
                    covariance = np.linalg.multi_dot((
         | 
| 119 121 | 
             
                        self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
         | 
| 120 122 |  | 
| @@ -138,9 +140,9 @@ class KalmanFilter(object): | |
| 138 140 |  | 
| 139 141 | 
             
                    """
         | 
| 140 142 | 
             
                    std = [
         | 
| 143 | 
            +
                        self._std_weight_position * mean[2],
         | 
| 141 144 | 
             
                        self._std_weight_position * mean[3],
         | 
| 142 | 
            -
                        self._std_weight_position * mean[ | 
| 143 | 
            -
                        1e-1,
         | 
| 145 | 
            +
                        self._std_weight_position * mean[2],
         | 
| 144 146 | 
             
                        self._std_weight_position * mean[3]]
         | 
| 145 147 | 
             
                    innovation_cov = np.diag(np.square(std))
         | 
| 146 148 |  | 
| @@ -149,6 +151,45 @@ class KalmanFilter(object): | |
| 149 151 | 
             
                        self._update_mat, covariance, self._update_mat.T))
         | 
| 150 152 | 
             
                    return mean, covariance + innovation_cov
         | 
| 151 153 |  | 
| 154 | 
            +
                def multi_predict(self, mean, covariance):
         | 
| 155 | 
            +
                    """Run Kalman filter prediction step (Vectorized version).
         | 
| 156 | 
            +
                    Parameters
         | 
| 157 | 
            +
                    ----------
         | 
| 158 | 
            +
                    mean : ndarray
         | 
| 159 | 
            +
                        The Nx8 dimensional mean matrix of the object states at the previous
         | 
| 160 | 
            +
                        time step.
         | 
| 161 | 
            +
                    covariance : ndarray
         | 
| 162 | 
            +
                        The Nx8x8 dimensional covariance matrics of the object states at the
         | 
| 163 | 
            +
                        previous time step.
         | 
| 164 | 
            +
                    Returns
         | 
| 165 | 
            +
                    -------
         | 
| 166 | 
            +
                    (ndarray, ndarray)
         | 
| 167 | 
            +
                        Returns the mean vector and covariance matrix of the predicted
         | 
| 168 | 
            +
                        state. Unobserved velocities are initialized to 0 mean.
         | 
| 169 | 
            +
                    """
         | 
| 170 | 
            +
                    std_pos = [
         | 
| 171 | 
            +
                        self._std_weight_position * mean[:, 2],
         | 
| 172 | 
            +
                        self._std_weight_position * mean[:, 3],
         | 
| 173 | 
            +
                        self._std_weight_position * mean[:, 2],
         | 
| 174 | 
            +
                        self._std_weight_position * mean[:, 3]]
         | 
| 175 | 
            +
                    std_vel = [
         | 
| 176 | 
            +
                        self._std_weight_velocity * mean[:, 2],
         | 
| 177 | 
            +
                        self._std_weight_velocity * mean[:, 3],
         | 
| 178 | 
            +
                        self._std_weight_velocity * mean[:, 2],
         | 
| 179 | 
            +
                        self._std_weight_velocity * mean[:, 3]]
         | 
| 180 | 
            +
                    sqr = np.square(np.r_[std_pos, std_vel]).T
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    motion_cov = []
         | 
| 183 | 
            +
                    for i in range(len(mean)):
         | 
| 184 | 
            +
                        motion_cov.append(np.diag(sqr[i]))
         | 
| 185 | 
            +
                    motion_cov = np.asarray(motion_cov)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    mean = np.dot(mean, self._motion_mat.T)
         | 
| 188 | 
            +
                    left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
         | 
| 189 | 
            +
                    covariance = np.dot(left, self._motion_mat.T) + motion_cov
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    return mean, covariance
         | 
| 192 | 
            +
             | 
| 152 193 | 
             
                def update(self, mean, covariance, measurement):
         | 
| 153 194 | 
             
                    """Run Kalman filter correction step.
         | 
| 154 195 |  | 
| @@ -159,8 +200,8 @@ class KalmanFilter(object): | |
| 159 200 | 
             
                    covariance : ndarray
         | 
| 160 201 | 
             
                        The state's covariance matrix (8x8 dimensional).
         | 
| 161 202 | 
             
                    measurement : ndarray
         | 
| 162 | 
            -
                        The 4 dimensional measurement vector (x, y,  | 
| 163 | 
            -
                        is the center position,  | 
| 203 | 
            +
                        The 4 dimensional measurement vector (x, y, w, h), where (x, y)
         | 
| 204 | 
            +
                        is the center position, w the width, and h the height of the
         | 
| 164 205 | 
             
                        bounding box.
         | 
| 165 206 |  | 
| 166 207 | 
             
                    Returns
         | 
| @@ -169,8 +210,6 @@ class KalmanFilter(object): | |
| 169 210 | 
             
                        Returns the measurement-corrected state distribution.
         | 
| 170 211 |  | 
| 171 212 | 
             
                    """
         | 
| 172 | 
            -
                    import scipy.linalg  # pylint: disable=import-error
         | 
| 173 | 
            -
             | 
| 174 213 | 
             
                    projected_mean, projected_cov = self.project(mean, covariance)
         | 
| 175 214 |  | 
| 176 215 | 
             
                    chol_factor, lower = scipy.linalg.cho_factor(
         | 
| @@ -186,13 +225,11 @@ class KalmanFilter(object): | |
| 186 225 | 
             
                    return new_mean, new_covariance
         | 
| 187 226 |  | 
| 188 227 | 
             
                def gating_distance(self, mean, covariance, measurements,
         | 
| 189 | 
            -
                                    only_position=False):
         | 
| 228 | 
            +
                                    only_position=False, metric='maha'):
         | 
| 190 229 | 
             
                    """Compute gating distance between state distribution and measurements.
         | 
| 191 | 
            -
             | 
| 192 230 | 
             
                    A suitable distance threshold can be obtained from `chi2inv95`. If
         | 
| 193 231 | 
             
                    `only_position` is False, the chi-square distribution has 4 degrees of
         | 
| 194 232 | 
             
                    freedom, otherwise 2.
         | 
| 195 | 
            -
             | 
| 196 233 | 
             
                    Parameters
         | 
| 197 234 | 
             
                    ----------
         | 
| 198 235 | 
             
                    mean : ndarray
         | 
| @@ -206,26 +243,27 @@ class KalmanFilter(object): | |
| 206 243 | 
             
                    only_position : Optional[bool]
         | 
| 207 244 | 
             
                        If True, distance computation is done with respect to the bounding
         | 
| 208 245 | 
             
                        box center position only.
         | 
| 209 | 
            -
             | 
| 210 246 | 
             
                    Returns
         | 
| 211 247 | 
             
                    -------
         | 
| 212 248 | 
             
                    ndarray
         | 
| 213 249 | 
             
                        Returns an array of length N, where the i-th element contains the
         | 
| 214 250 | 
             
                        squared Mahalanobis distance between (mean, covariance) and
         | 
| 215 251 | 
             
                        `measurements[i]`.
         | 
| 216 | 
            -
             | 
| 217 252 | 
             
                    """
         | 
| 218 | 
            -
                    import scipy.linalg  # pylint: disable=import-error
         | 
| 219 | 
            -
             | 
| 220 253 | 
             
                    mean, covariance = self.project(mean, covariance)
         | 
| 221 254 | 
             
                    if only_position:
         | 
| 222 255 | 
             
                        mean, covariance = mean[:2], covariance[:2, :2]
         | 
| 223 256 | 
             
                        measurements = measurements[:, :2]
         | 
| 224 257 |  | 
| 225 | 
            -
                    cholesky_factor = np.linalg.cholesky(covariance)
         | 
| 226 258 | 
             
                    d = measurements - mean
         | 
| 227 | 
            -
                     | 
| 228 | 
            -
                         | 
| 229 | 
            -
             | 
| 230 | 
            -
             | 
| 231 | 
            -
             | 
| 259 | 
            +
                    if metric == 'gaussian':
         | 
| 260 | 
            +
                        return np.sum(d * d, axis=1)
         | 
| 261 | 
            +
                    elif metric == 'maha':
         | 
| 262 | 
            +
                        cholesky_factor = np.linalg.cholesky(covariance)
         | 
| 263 | 
            +
                        z = scipy.linalg.solve_triangular(
         | 
| 264 | 
            +
                            cholesky_factor, d.T, lower=True, check_finite=False,
         | 
| 265 | 
            +
                            overwrite_b=True)
         | 
| 266 | 
            +
                        squared_maha = np.sum(z * z, axis=0)
         | 
| 267 | 
            +
                        return squared_maha
         | 
| 268 | 
            +
                    else:
         | 
| 269 | 
            +
                        raise ValueError('invalid distance metric')
         |