msight-vision 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +0 -0
- cli/launch_2d_viewer.py +10 -0
- cli/launch_custom_fuser.py +23 -0
- cli/launch_finite_difference_state_estimator.py +22 -0
- cli/launch_road_user_list_viewer.py +23 -0
- cli/launch_sort_tracker.py +24 -0
- cli/launch_yolo_onestage_detection.py +22 -0
- msight_vision/__init__.py +8 -0
- msight_vision/base.py +99 -0
- msight_vision/detector_yolo.py +87 -0
- msight_vision/fuser.py +325 -0
- msight_vision/localizer.py +32 -0
- msight_vision/msight_core/__init__.py +6 -0
- msight_vision/msight_core/detection.py +103 -0
- msight_vision/msight_core/fusion.py +64 -0
- msight_vision/msight_core/state_estimation.py +38 -0
- msight_vision/msight_core/tracking.py +31 -0
- msight_vision/msight_core/viewer.py +55 -0
- msight_vision/msight_core/warper.py +98 -0
- msight_vision/state_estimator.py +121 -0
- msight_vision/tracker.py +525 -0
- msight_vision/utils/__init__.py +3 -0
- msight_vision/utils/data.py +80 -0
- msight_vision/utils/typing.py +18 -0
- msight_vision/utils/vis.py +17 -0
- msight_vision/warper.py +89 -0
- msight_vision-0.1.0.dist-info/METADATA +28 -0
- msight_vision-0.1.0.dist-info/RECORD +31 -0
- msight_vision-0.1.0.dist-info/WHEEL +5 -0
- msight_vision-0.1.0.dist-info/entry_points.txt +7 -0
- msight_vision-0.1.0.dist-info/top_level.txt +2 -0
cli/__init__.py
ADDED
|
File without changes
|
cli/launch_2d_viewer.py
ADDED
|
from msight_vision.msight_core import DetectionResults2DViewerNode
from msight_core.utils import get_node_config_from_args, get_default_arg_parser


def main():
    """Parse CLI arguments and run the Detection Results 2D Viewer node."""
    parser = get_default_arg_parser(description="Launch Detection Results 2D Viewer Node", node_class=DetectionResults2DViewerNode)
    args = parser.parse_args()
    detection_node = DetectionResults2DViewerNode(
        configs=get_node_config_from_args(args)
    )
    # Blocks until the node is shut down.
    detection_node.spin()


# Every sibling launcher in cli/ guards its entry point; this one was missing
# the guard, so importing the module would not run main() via the guard path.
if __name__ == "__main__":
    main()
|
from msight_vision.msight_core import FuserNode
from msight_core.utils import get_node_config_from_args, get_default_arg_parser
import time


def main():
    """Parse CLI arguments and spin up a FuserNode."""
    arg_parser = get_default_arg_parser(description="Launch Fuser Node", node_class=FuserNode)
    arg_parser.add_argument("--fusion-config", "-fc", type=str, required=True, help="Path to the configuration file")
    arg_parser.add_argument("--wait", "-w", type=int, default=0, help="Wait time before starting the node (in seconds)")
    cli_args = arg_parser.parse_args()

    # Optional startup delay, e.g. to let upstream nodes come online first.
    if cli_args.wait > 0:
        print(f"Waiting for {cli_args.wait} seconds before starting the node...")
        time.sleep(cli_args.wait)

    node = FuserNode(
        get_node_config_from_args(cli_args),
        cli_args.fusion_config,
    )
    node.spin()


if __name__ == "__main__":
    main()
|
from msight_vision.msight_core import FiniteDifferenceStateEstimatorNode
from msight_core.utils import get_node_config_from_args, get_default_arg_parser
import time


def main():
    """Parse CLI arguments and spin up a FiniteDifferenceStateEstimatorNode."""
    arg_parser = get_default_arg_parser(description="Launch Finite Difference State Estimator Node", node_class=FiniteDifferenceStateEstimatorNode)
    arg_parser.add_argument("--estimator-configs", "-ec", type=str, required=True, help="Path to the configuration file")
    arg_parser.add_argument("--wait", "-w", type=int, default=0, help="Wait time before starting the node (in seconds)")
    cli_args = arg_parser.parse_args()

    # Optional startup delay, e.g. to let upstream nodes come online first.
    if cli_args.wait > 0:
        print(f"Waiting for {cli_args.wait} seconds before starting the node...")
        time.sleep(cli_args.wait)

    node = FiniteDifferenceStateEstimatorNode(
        get_node_config_from_args(cli_args),
        cli_args.estimator_configs,
    )
    node.spin()


if __name__ == "__main__":
    main()
|
from msight_vision.msight_core import RoadUserListViewerNode
from msight_core.utils import get_node_config_from_args, get_default_arg_parser


def main():
    """Parse CLI arguments and run the Road User List Viewer node."""
    parser = get_default_arg_parser(description="Launch Road User List Viewer Node", node_class=RoadUserListViewerNode)
    parser.add_argument("--basemap", type=str, required=True, help="Path to the basemap image")
    parser.add_argument("--show-trajectory", action='store_true', help="Flag to show trajectory")
    parser.add_argument("--show-heading", action='store_true', help="Flag to show heading")
    args = parser.parse_args()
    configs = get_node_config_from_args(args)

    # Dead commented-out code (an unused topic lookup and a stale `False`
    # positional placeholder) removed; behavior is unchanged.
    detection_node = RoadUserListViewerNode(
        configs,
        args.basemap,
        args.show_trajectory,
        args.show_heading,
    )
    detection_node.spin()


if __name__ == "__main__":
    main()
|
from msight_vision.msight_core import SortTrackerNode
from msight_core.utils import get_node_config_from_args, get_default_arg_parser
import time


def main():
    """Parse CLI arguments and spin up a SortTrackerNode."""
    arg_parser = get_default_arg_parser(description="Launch SORT Tracker Node", node_class=SortTrackerNode)
    arg_parser.add_argument("--tracking-configs", "-tc", type=str, required=True,
                            help="Path to the configuration file")
    arg_parser.add_argument("--wait", "-w", type=int, default=0, help="Wait time before starting the node (in seconds)")
    cli_args = arg_parser.parse_args()

    # Optional startup delay, e.g. to let upstream nodes come online first.
    if cli_args.wait > 0:
        print(f"Waiting for {cli_args.wait} seconds before starting the node...")
        time.sleep(cli_args.wait)

    node = SortTrackerNode(
        get_node_config_from_args(cli_args),
        cli_args.tracking_configs,
    )
    node.spin()


if __name__ == "__main__":
    main()
|
from msight_vision.msight_core import YoloOneStageDetectionNode
from msight_core.utils import get_node_config_from_args, get_default_arg_parser
import time


def main():
    """Parse CLI arguments and spin up a YoloOneStageDetectionNode."""
    arg_parser = get_default_arg_parser(description="Launch YOLO One-Stage Detection Node", node_class=YoloOneStageDetectionNode)
    arg_parser.add_argument("--det-configs", "-dc", type=str, required=True, help="Path to the configuration file")
    arg_parser.add_argument("--wait", "-w", type=int, default=0, help="Wait time before starting the node (in seconds)")
    cli_args = arg_parser.parse_args()

    # Optional startup delay, e.g. to let upstream nodes come online first.
    if cli_args.wait > 0:
        print(f"Waiting for {cli_args.wait} seconds before starting the node...")
        time.sleep(cli_args.wait)

    node = YoloOneStageDetectionNode(
        get_node_config_from_args(cli_args),
        cli_args.det_configs,
    )
    node.spin()


if __name__ == "__main__":
    main()
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from importlib.metadata import version
|
|
2
|
+
from .detector_yolo import YoloDetector, Yolo26Detector
|
|
3
|
+
from .localizer import HashLocalizer
|
|
4
|
+
from .tracker import SortTracker
|
|
5
|
+
from .warper import ClassicWarper, ClassicWarperWithExternalUpdate
|
|
6
|
+
from .fuser import FuserBase
|
|
7
|
+
from .state_estimator import FiniteDifferenceStateEstimator
|
|
8
|
+
__version__ = version("msight_vision")
|
msight_vision/base.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from msight_base import DetectionResultBase, DetectedObjectBase, RoadUserPoint
|
|
2
|
+
import numpy as np
|
|
3
|
+
from typing import List, Dict
|
|
4
|
+
|
|
5
|
+
class DetectedObject2D(DetectedObjectBase):
    """Detected object for 2D images."""

    def __init__(self, box: list, class_id: int, score: float, pixel_bottom_center: List[float], obj_id: str = None, lat: float = None, lon: float = None, x: float = None, y: float = None):
        """
        Initialize the detected object.
        :param box: bounding box coordinates (x1, y1, x2, y2)
        :param class_id: class ID of the detected object
        :param score: confidence score of the detection
        :param pixel_bottom_center: pixel coordinates used as the object's anchor point
        :param obj_id: unique ID of the detected object (optional)
        :param lat: latitude of the detected object (optional)
        :param lon: longitude of the detected object (optional)
        :param x: x coordinate in the coordinate frame of interest, e.g. UTM (optional)
        :param y: y coordinate in the coordinate frame of interest, e.g. UTM (optional)
        """
        super().__init__()
        self.box = box
        self.class_id = class_id
        self.score = score
        self.pixel_bottom_center = pixel_bottom_center
        self.obj_id = obj_id
        self.lat = lat
        self.lon = lon
        self.x = x
        self.y = y

    def to_dict(self):
        """
        Convert the detected object to a dictionary.
        :return: dictionary representation of the detected object
        """
        return {
            "box": self.box,
            "class_id": self.class_id,
            "score": self.score,
            "obj_id": self.obj_id,
            "lat": self.lat,
            "lon": self.lon,
            "x": self.x,
            "y": self.y,
            "pixel_bottom_center": self.pixel_bottom_center,
        }

    @staticmethod
    def from_dict(data: dict):
        """
        Create a DetectedObject2D instance from a dictionary.
        :param data: dictionary representation of the detected object
        :return: DetectedObject2D instance
        """
        # BUGFIX: the previous `data.get(...) or None` coerced valid falsy
        # values to None — e.g. lat/lon of 0.0 (equator/prime meridian) or
        # x/y of 0 (frame origin) were silently dropped. Plain .get()
        # already returns None for missing keys.
        return DetectedObject2D(
            box=data["box"],
            class_id=data["class_id"],
            score=data["score"],
            obj_id=data.get("obj_id"),
            lat=data.get("lat"),
            lon=data.get("lon"),
            x=data.get("x"),
            y=data.get("y"),
            pixel_bottom_center=data.get("pixel_bottom_center"),
        )

    def __repr__(self):
        return f"DetectedObject2D(box={self.box}, class_id={self.class_id}, score={self.score}, obj_id={self.obj_id}, lat={self.lat}, lon={self.lon}, x={self.x}, y={self.y})"
|
+
class DetectionResult2D(DetectionResultBase):
    """Detection result for a single 2D image frame."""

    def __init__(self, object_list: List[DetectedObject2D], timestamp: int, sensor_type: str):
        """
        Initialize the detection result.
        :param object_list: detected objects in the frame
        :param timestamp: capture timestamp of the frame
        :param sensor_type: type of the sensor that produced the frame
        """
        super().__init__(object_list, timestamp, sensor_type)
|
+
class ImageDetector2DBase:
    """Abstract base for 2D image detectors."""

    def detect(self, image: np.ndarray) -> DetectionResult2D:
        """
        Run detection on an image.
        :param image: input image
        :return: detection result for the image
        :raises NotImplementedError: subclasses must override this method
        """
        raise NotImplementedError("detect method not implemented")
|
89
|
+
class TrackerBase:
    """Abstract base for multi-object trackers."""

    def __init__(self):
        pass

    def track(self, detection_result) -> Dict[str, RoadUserPoint]:
        """
        Track the detected objects in the image.

        Note: the parameter was previously named ``list``, which shadowed the
        builtin and contradicted this docstring; renamed to the documented
        name (positional callers are unaffected).
        :param detection_result: DetectionResult2D instance
        :return: mapping from track ID to RoadUserPoint with tracking information
        :raises NotImplementedError: subclasses must override this method
        """
        raise NotImplementedError("track method not implemented")
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from numpy import ndarray
|
|
2
|
+
from msight_vision.base import DetectionResult2D, DetectedObject2D
|
|
3
|
+
from .base import ImageDetector2DBase
|
|
4
|
+
from ultralytics import YOLO
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
class YoloDetector(ImageDetector2DBase):
    """Ultralytics-YOLO-based detector for 2D images."""

    def __init__(self, model_path: Path, device: str = "cpu", confthre: float = 0.25, nmsthre: float = 0.45, fp16: bool = False, class_agnostic_nms: bool = False):
        """
        Initialize the YOLO detector.
        :param model_path: path to the YOLO model weights
        :param device: device to run the model on (e.g., 'cpu', 'cuda')
        :param confthre: confidence threshold for detections
        :param nmsthre: IoU threshold used for non-maximum suppression
        :param fp16: whether to run inference in half precision
        :param class_agnostic_nms: whether NMS merges boxes regardless of class
        """
        super().__init__()
        self.model = YOLO(str(model_path))
        self.device = device
        self.confthre = confthre
        self.nmsthre = nmsthre
        self.fp16 = fp16
        self.class_agnostic_nms = class_agnostic_nms

    def convert_yolo_result_to_detection_result(self, yolo_output_results, timestamp, sensor_type):
        """
        Convert raw YOLO output into a DetectionResult2D.
        :param yolo_output_results: YOLO output results (first element is used)
        :param timestamp: timestamp of the image
        :param sensor_type: type of the sensor
        :return: DetectionResult2D instance
        """
        result_boxes = yolo_output_results[0].boxes
        xyxy = result_boxes.xyxy.cpu().numpy()
        scores = result_boxes.conf.cpu().numpy()
        labels = result_boxes.cls.cpu().numpy()

        detected_objects = []
        for coords, label, score in zip(xyxy, labels, scores):
            x1, y1, x2, y2 = coords[0], coords[1], coords[2], coords[3]
            # NOTE(review): despite the field name, this is the box *center*
            # ((x1+x2)/2, (y1+y2)/2), not the bottom center ((x1+x2)/2, y2) —
            # confirm which anchor downstream localization expects.
            anchor = [float((x1 + x2) / 2), float((y1 + y2) / 2)]
            detected_objects.append(
                DetectedObject2D(
                    box=[float(x1), float(y1), float(x2), float(y2)],
                    class_id=int(label),
                    score=float(score),
                    pixel_bottom_center=anchor,
                )
            )

        return DetectionResult2D(
            detected_objects,
            timestamp,
            sensor_type,
        )

    def detect(self, image: ndarray, timestamp, sensor_type) -> DetectionResult2D:
        """
        Run YOLO inference on an image and wrap the output.
        :param image: input image
        :param timestamp: timestamp of the image
        :param sensor_type: type of the sensor
        :return: DetectionResult2D instance
        """
        raw_results = self.model(image, device=self.device, conf=self.confthre, iou=self.nmsthre, half=self.fp16, verbose=False, agnostic_nms=self.class_agnostic_nms)
        return self.convert_yolo_result_to_detection_result(raw_results, timestamp, sensor_type)
|
+
class Yolo26Detector(YoloDetector):
    """YOLO26 detector for 2D images; adds an end-to-end inference switch."""

    def __init__(self, model_path: Path, device: str = "cpu", confthre: float = 0.25, nmsthre: float = 0.45, fp16: bool = False, class_agnostic_nms: bool = False, end2end: bool = False):
        """
        Initialize the YOLO26 detector.
        :param end2end: whether to run the model in end-to-end mode
        (remaining parameters are forwarded to YoloDetector)
        """
        super().__init__(model_path, device, confthre, nmsthre, fp16, class_agnostic_nms)
        self.end2end = end2end

    def detect(self, image: ndarray, timestamp, sensor_type) -> DetectionResult2D:
        """
        Run YOLO26 inference (with the extra end2end flag) and wrap the output.
        :param image: input image
        :param timestamp: timestamp of the image
        :param sensor_type: type of the sensor
        :return: DetectionResult2D instance
        """
        raw_results = self.model(image, device=self.device, conf=self.confthre, iou=self.nmsthre, half=self.fp16, verbose=False, agnostic_nms=self.class_agnostic_nms, end2end=self.end2end)
        return self.convert_yolo_result_to_detection_result(raw_results, timestamp, sensor_type)
|
msight_vision/fuser.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
from typing import Dict, List, Tuple
|
|
2
|
+
from .base import DetectionResult2D
|
|
3
|
+
from geopy.distance import geodesic
|
|
4
|
+
from .utils import detection_to_roaduser_point
|
|
5
|
+
from msight_base import RoadUserPoint
|
|
6
|
+
from scipy.optimize import linear_sum_assignment
|
|
7
|
+
from shapely.geometry import Point, Polygon
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
class FuserBase:
    """Abstract base for multi-sensor detection fusers."""

    def fuse(self, results: Dict[str, DetectionResult2D]) -> List[RoadUserPoint]:
        """
        Fuse per-sensor detection results into a single road-user list.
        :param results: mapping from sensor ID to its detection result
        :return: fused list of road users
        :raises NotImplementedError: subclasses must override this method
        """
        raise NotImplementedError("FuserBase is an abstract class and cannot be instantiated directly.")
|
19
|
+
## This is a simple example of a fuser that combines the results from different cameras, which works at the roundabout of State and Ellsworth in Smart Intersection Project.
|
|
20
|
+
class StateEllsworthFuser(FuserBase):
|
|
21
|
+
'''
|
|
22
|
+
This is a simple example of a fuser that combines the results from different cameras, which works at the roundabouot of State and Ellsworth.
|
|
23
|
+
'''
|
|
24
|
+
def __init__(self):
|
|
25
|
+
self.lat1 = 42.229379
|
|
26
|
+
self.lon1 = -83.739003
|
|
27
|
+
self.lat2 = 42.229444
|
|
28
|
+
self.lon2 = -83.739013
|
|
29
|
+
|
|
30
|
+
def fuse(self, detection_buffer: Dict[str, DetectionResult2D]) -> List[RoadUserPoint]:
|
|
31
|
+
fused_vehicle_list = []
|
|
32
|
+
|
|
33
|
+
vehicle_list = detection_buffer['gs_State_Ellsworth_NW'].object_list
|
|
34
|
+
for v in vehicle_list: # cam_ne
|
|
35
|
+
if v.lat > self.lat1 and v.lon > self.lon1:
|
|
36
|
+
fused_vehicle_list.append(detection_to_roaduser_point(v, 'gs_State_Ellsworth_NW'))
|
|
37
|
+
|
|
38
|
+
vehicle_list = detection_buffer['gs_State_Ellsworth_NE'].object_list
|
|
39
|
+
for v in vehicle_list: # cam_nw
|
|
40
|
+
if v.lat > self.lat2 and v.lon < self.lon2:
|
|
41
|
+
fused_vehicle_list.append(detection_to_roaduser_point(v, 'gs_State_Ellsworth_NE'))
|
|
42
|
+
|
|
43
|
+
vehicle_list = detection_buffer['gs_State_Ellsworth_SE'].object_list
|
|
44
|
+
for v in vehicle_list: # cam_se
|
|
45
|
+
if v.lat < self.lat1 and v.lon > self.lon1:
|
|
46
|
+
fused_vehicle_list.append(detection_to_roaduser_point(v, 'gs_State_Ellsworth_SE'))
|
|
47
|
+
|
|
48
|
+
vehicle_list = detection_buffer['gs_State_Ellsworth_SW'].object_list
|
|
49
|
+
for v in vehicle_list: # cam_sw
|
|
50
|
+
if v.lat < self.lat2 and v.lon < self.lon2:
|
|
51
|
+
fused_vehicle_list.append(detection_to_roaduser_point(v, 'gs_State_Ellsworth_SW'))
|
|
52
|
+
return fused_vehicle_list
|
|
53
|
+
|
|
54
|
+
class HungarianFuser(FuserBase):
    """
    A fuser that matches detections from multiple sensors based on spatial proximity
    using Hungarian algorithm and fuses their locations using weighted averaging.
    """
    def __init__(self, coverage_zones: dict, sensor_locations: dict = None, distance_threshold: float = 5.0):
        """
        Initialize the HungarianFuser.
        :param coverage_zones: dict mapping sensor_id to a polygon defining the sensor's coverage zone.
                               Each polygon is a list of (lat, lon) tuples forming a closed polygon.
                               Example: [(lat1, lon1), (lat2, lon2), (lat3, lon3), ...]
        :param sensor_locations: dict mapping sensor_id to (lat, lon) tuple of the sensor's location.
                                 If provided, weights are computed as 1/distance_to_sensor^2.
                                 If not provided, bounding box area is used as weight.
        :param distance_threshold: maximum distance (in meters) to consider two detections as the same object.
        """
        self.coverage_zones = coverage_zones
        self.sensor_locations = sensor_locations
        self.distance_threshold = distance_threshold
        # Sensor processing order is fixed by the insertion order of coverage_zones.
        self.sensor_list = list(coverage_zones.keys())

        # Pre-compute Shapely Polygon objects for efficient point-in-polygon checks.
        # A falsy polygon (None / empty list) disables coverage filtering for that sensor.
        self._coverage_polygons = {
            sensor_id: Polygon(polygon) if polygon else None
            for sensor_id, polygon in coverage_zones.items()
        }

    def _is_in_coverage(self, detected_object, sensor_id: str) -> bool:
        """
        Check if a detected object is within the sensor's coverage zone.
        :param detected_object: DetectedObject2D instance
        :param sensor_id: sensor identifier
        :return: True if in coverage zone, False otherwise
        """
        polygon = self._coverage_polygons.get(sensor_id)
        if polygon is None:
            return True  # No coverage filter defined, include all
        # Points and polygons both use (lat, lon) ordering, so the
        # containment test is self-consistent.
        point = Point(detected_object.lat, detected_object.lon)
        return polygon.contains(point)

    def _compute_weight(self, detected_object, sensor_id: str) -> float:
        """
        Compute the weight for a detected object.
        If sensor location is available, use 1/distance_to_sensor^2.
        Otherwise, use the bounding box area.
        :param detected_object: DetectedObject2D instance
        :param sensor_id: sensor identifier
        :return: weight value
        """
        if self.sensor_locations is not None and sensor_id in self.sensor_locations:
            sensor_lat, sensor_lon = self.sensor_locations[sensor_id]
            # Compute geodesic distance to sensor in meters
            dist = geodesic((detected_object.lat, detected_object.lon), (sensor_lat, sensor_lon)).meters
            dist_sq = dist ** 2
            if dist_sq < 1e-10:
                dist_sq = 1e-10  # Avoid division by zero
            return 1.0 / dist_sq
        else:
            # Use bounding box area as weight
            # box is [x1, y1, x2, y2]
            box = detected_object.box
            width = box[2] - box[0]
            height = box[3] - box[1]
            area = width * height
            # Degenerate (zero/negative area) boxes fall back to a unit weight.
            return area if area > 0 else 1.0

    def _compute_distance_to_group(self, detected_object, group: dict) -> float:
        """
        Compute the geodesic distance between a detected object and a group's weighted location.
        :param detected_object: DetectedObject2D instance
        :param group: group dict containing 'weighted_lat' and 'weighted_lon'
        :return: distance in meters
        """
        return geodesic(
            (detected_object.lat, detected_object.lon),
            (group['weighted_lat'], group['weighted_lon'])
        ).meters

    def _filter_detections_by_sensor(self, detection_buffer: Dict[str, DetectionResult2D]) -> Dict[str, List]:
        """
        Filter detections by coverage zone and organize by sensor.
        :param detection_buffer: dict mapping sensor_id to DetectionResult2D
        :return: dict mapping sensor_id to list of valid DetectedObject2D instances
        """
        detections_by_sensor = {}
        for sensor_id in self.sensor_list:
            # Sensors with no data this cycle are simply skipped.
            if sensor_id not in detection_buffer:
                continue
            detection_result = detection_buffer[sensor_id]
            valid_detections = []
            for detected_object in detection_result.object_list:
                # Skip objects without valid lat/lon
                if detected_object.lat is None or detected_object.lon is None:
                    continue
                # Filter by coverage zone
                if self._is_in_coverage(detected_object, sensor_id):
                    valid_detections.append(detected_object)
            if valid_detections:
                detections_by_sensor[sensor_id] = valid_detections
        return detections_by_sensor

    def _create_group_from_detection(self, detected_object, sensor_id: str) -> dict:
        """
        Create a new group from a single detection.
        :param detected_object: DetectedObject2D instance
        :param sensor_id: sensor identifier
        :return: group dict
        """
        weight = self._compute_weight(detected_object, sensor_id)
        # A single-member group's weighted location is just the detection's
        # location; the running sums seed later weighted-average updates.
        return {
            'weighted_lat': detected_object.lat,
            'weighted_lon': detected_object.lon,
            'total_weight': weight,
            'weighted_lat_sum': detected_object.lat * weight,
            'weighted_lon_sum': detected_object.lon * weight,
            'max_confidence': detected_object.score,
            'class_id_counts': {detected_object.class_id: 1},
            'sensor_data': {sensor_id: detected_object},
        }

    def _add_detection_to_group(self, group: dict, detected_object, sensor_id: str) -> None:
        """
        Add a detection to an existing group and update weighted location.
        :param group: group dict to update (mutated in place)
        :param detected_object: DetectedObject2D instance
        :param sensor_id: sensor identifier
        """
        weight = self._compute_weight(detected_object, sensor_id)

        # Update weighted sums (must happen before recomputing the averages below)
        group['weighted_lat_sum'] += detected_object.lat * weight
        group['weighted_lon_sum'] += detected_object.lon * weight
        group['total_weight'] += weight

        # Update weighted location
        group['weighted_lat'] = group['weighted_lat_sum'] / group['total_weight']
        group['weighted_lon'] = group['weighted_lon_sum'] / group['total_weight']

        # Update confidence (group keeps the maximum over its members)
        if detected_object.score > group['max_confidence']:
            group['max_confidence'] = detected_object.score

        # Update class_id counts
        class_id = detected_object.class_id
        group['class_id_counts'][class_id] = group['class_id_counts'].get(class_id, 0) + 1

        # Store sensor data (keep the detected object directly)
        group['sensor_data'][sensor_id] = detected_object

    def _hungarian_match(self, groups: List[dict], detections: List, sensor_id: str) -> Tuple[List[dict], List]:
        """
        Use Hungarian algorithm to match detections to existing groups.
        :param groups: list of existing group dicts
        :param detections: list of DetectedObject2D instances from a single sensor
        :param sensor_id: sensor identifier
        :return: (updated groups, unmatched detections)
        """
        if not groups or not detections:
            return groups, detections

        n_groups = len(groups)
        n_detections = len(detections)

        # Build cost matrix (distance between each detection and each group)
        cost_matrix = np.zeros((n_detections, n_groups))
        for i, det in enumerate(detections):
            for j, group in enumerate(groups):
                dist = self._compute_distance_to_group(det, group)
                cost_matrix[i, j] = dist

        # Apply Hungarian algorithm (minimizes total assignment distance)
        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        matched_detection_indices = set()
        matched_group_indices = set()

        # Process matches; assignments beyond the distance threshold are
        # rejected and their detections treated as unmatched.
        for det_idx, group_idx in zip(row_ind, col_ind):
            dist = cost_matrix[det_idx, group_idx]
            if dist <= self.distance_threshold:
                # Valid match - add detection to group
                self._add_detection_to_group(groups[group_idx], detections[det_idx], sensor_id)
                matched_detection_indices.add(det_idx)
                matched_group_indices.add(group_idx)

        # Collect unmatched detections
        unmatched_detections = [
            detections[i] for i in range(n_detections) if i not in matched_detection_indices
        ]

        return groups, unmatched_detections

    def _group_to_road_user_point(self, group: dict) -> RoadUserPoint:
        """
        Convert a group to a RoadUserPoint.
        :param group: group dict
        :return: RoadUserPoint instance
        """
        # Determine the most common class_id (majority vote across sensors)
        most_common_class = max(group['class_id_counts'], key=group['class_id_counts'].get)

        # Convert sensor_data to dict format at the final stage
        sensor_data_dict = {
            sensor_id: det_obj.to_dict()
            for sensor_id, det_obj in group['sensor_data'].items()
        }

        # Create the fused RoadUserPoint
        # NOTE(review): latitude is stored in x and longitude in y — confirm
        # this matches RoadUserPoint's coordinate-frame convention.
        road_user_point = RoadUserPoint(
            x=group['weighted_lat'],
            y=group['weighted_lon'],
            category=most_common_class,
            confidence=group['max_confidence'],
        )
        road_user_point.sensor_data = sensor_data_dict

        return road_user_point

    def fuse(self, detection_buffer: Dict[str, DetectionResult2D]) -> List[RoadUserPoint]:
        """
        Fuse detections from multiple sensors using Hungarian matching.

        Algorithm:
        1. For the first sensor, create a group for each detection (single object groups)
        2. For each subsequent sensor:
           a. Use Hungarian algorithm to match detections to existing groups
           b. For matched pairs within distance_threshold, add detection to group and update weighted location
           c. For unmatched detections, create new single-object groups
        3. Convert all groups to RoadUserPoints

        :param detection_buffer: dict mapping sensor_id to DetectionResult2D
        :return: list of fused RoadUserPoint instances
        """
        # Step 1: Filter and organize detections by sensor
        detections_by_sensor = self._filter_detections_by_sensor(detection_buffer)

        if not detections_by_sensor:
            return []

        # Get list of sensors that have detections (preserve order from sensor_list)
        active_sensors = [s for s in self.sensor_list if s in detections_by_sensor]

        if not active_sensors:
            return []

        # Step 2: Initialize groups with first sensor's detections
        first_sensor = active_sensors[0]
        groups = []
        for det in detections_by_sensor[first_sensor]:
            group = self._create_group_from_detection(det, first_sensor)
            groups.append(group)

        # Step 3: Process remaining sensors
        for sensor_id in active_sensors[1:]:
            detections = detections_by_sensor[sensor_id]

            # Hungarian matching against existing groups
            groups, unmatched_detections = self._hungarian_match(groups, detections, sensor_id)

            # Create new groups for unmatched detections
            for det in unmatched_detections:
                group = self._create_group_from_detection(det, sensor_id)
                groups.append(group)

        # Step 4: Convert groups to RoadUserPoints
        fused_results = []
        for group in groups:
            road_user_point = self._group_to_road_user_point(group)
            fused_results.append(road_user_point)

        return fused_results
|
325
|
+
|