msight-vision 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cli/__init__.py ADDED
File without changes
@@ -0,0 +1,10 @@
from msight_vision.msight_core import DetectionResults2DViewerNode
from msight_core.utils import get_node_config_from_args, get_default_arg_parser


def main():
    """Parse CLI args and run the Detection Results 2D Viewer node until stopped."""
    parser = get_default_arg_parser(
        description="Launch Detection Results 2D Viewer Node",
        node_class=DetectionResults2DViewerNode,
    )
    args = parser.parse_args()
    viewer_node = DetectionResults2DViewerNode(
        configs=get_node_config_from_args(args)
    )
    # spin() blocks, processing messages until the node is shut down.
    viewer_node.spin()


# Entry-point guard was missing here, unlike the sibling CLI scripts; without
# it the node never starts when this module is executed directly.
if __name__ == "__main__":
    main()
@@ -0,0 +1,23 @@
from msight_vision.msight_core import FuserNode
from msight_core.utils import get_node_config_from_args, get_default_arg_parser
import time


def main():
    """Parse CLI args, optionally delay startup, then run the Fuser node."""
    parser = get_default_arg_parser(description="Launch Fuser Node", node_class=FuserNode)
    parser.add_argument("--fusion-config", "-fc", type=str, required=True, help="Path to the configuration file")
    parser.add_argument("--wait", "-w", type=int, default=0, help="Wait time before starting the node (in seconds)")
    args = parser.parse_args()

    # Optional startup delay (e.g. to let upstream nodes come online first).
    if args.wait > 0:
        print(f"Waiting for {args.wait} seconds before starting the node...")
        time.sleep(args.wait)

    node = FuserNode(
        get_node_config_from_args(args),
        args.fusion_config,
    )
    # spin() blocks until the node is shut down.
    node.spin()


if __name__ == "__main__":
    main()
@@ -0,0 +1,22 @@
from msight_vision.msight_core import FiniteDifferenceStateEstimatorNode
from msight_core.utils import get_node_config_from_args, get_default_arg_parser
import time


def main():
    """Parse CLI args, optionally delay startup, then run the state estimator node."""
    parser = get_default_arg_parser(description="Launch Finite Difference State Estimator Node", node_class=FiniteDifferenceStateEstimatorNode)
    parser.add_argument("--estimator-configs", "-ec", type=str, required=True, help="Path to the configuration file")
    parser.add_argument("--wait", "-w", type=int, default=0, help="Wait time before starting the node (in seconds)")
    args = parser.parse_args()

    # Optional startup delay (e.g. to let upstream nodes come online first).
    if args.wait > 0:
        print(f"Waiting for {args.wait} seconds before starting the node...")
        time.sleep(args.wait)

    node = FiniteDifferenceStateEstimatorNode(
        get_node_config_from_args(args),
        args.estimator_configs,
    )
    # spin() blocks until the node is shut down.
    node.spin()


if __name__ == "__main__":
    main()
@@ -0,0 +1,23 @@
from msight_vision.msight_core import RoadUserListViewerNode
from msight_core.utils import get_node_config_from_args, get_default_arg_parser


def main():
    """Parse CLI args and run the Road User List Viewer node on a basemap image."""
    parser = get_default_arg_parser(description="Launch Road User List Viewer Node", node_class=RoadUserListViewerNode)
    parser.add_argument("--basemap", type=str, required=True, help="Path to the basemap image")
    parser.add_argument("--show-trajectory", action='store_true', help="Flag to show trajectory")
    parser.add_argument("--show-heading", action='store_true', help="Flag to show heading")
    args = parser.parse_args()
    configs = get_node_config_from_args(args)

    # Dead commented-out code (an unused topic lookup and a stale literal
    # argument) removed; the live argument list is unchanged.
    viewer_node = RoadUserListViewerNode(
        configs,
        args.basemap,
        args.show_trajectory,
        args.show_heading,
    )
    # spin() blocks until the node is shut down.
    viewer_node.spin()


if __name__ == "__main__":
    main()
@@ -0,0 +1,24 @@
from msight_vision.msight_core import SortTrackerNode
from msight_core.utils import get_node_config_from_args, get_default_arg_parser
import time


def main():
    """Parse CLI args, optionally delay startup, then run the SORT tracker node."""
    parser = get_default_arg_parser(description="Launch SORT Tracker Node", node_class=SortTrackerNode)
    parser.add_argument(
        "--tracking-configs", "-tc", type=str, required=True,
        help="Path to the configuration file",
    )
    parser.add_argument("--wait", "-w", type=int, default=0, help="Wait time before starting the node (in seconds)")
    args = parser.parse_args()

    # Optional startup delay (e.g. to let upstream nodes come online first).
    if args.wait > 0:
        print(f"Waiting for {args.wait} seconds before starting the node...")
        time.sleep(args.wait)

    node = SortTrackerNode(
        get_node_config_from_args(args),
        args.tracking_configs,
    )
    # spin() blocks until the node is shut down.
    node.spin()


if __name__ == "__main__":
    main()
@@ -0,0 +1,22 @@
from msight_vision.msight_core import YoloOneStageDetectionNode
from msight_core.utils import get_node_config_from_args, get_default_arg_parser
import time


def main():
    """Parse CLI args, optionally delay startup, then run the YOLO detection node."""
    parser = get_default_arg_parser(description="Launch YOLO One-Stage Detection Node", node_class=YoloOneStageDetectionNode)
    parser.add_argument("--det-configs", "-dc", type=str, required=True, help="Path to the configuration file")
    parser.add_argument("--wait", "-w", type=int, default=0, help="Wait time before starting the node (in seconds)")
    args = parser.parse_args()

    # Optional startup delay (e.g. to let upstream nodes come online first).
    if args.wait > 0:
        print(f"Waiting for {args.wait} seconds before starting the node...")
        time.sleep(args.wait)

    node = YoloOneStageDetectionNode(
        get_node_config_from_args(args),
        args.det_configs,
    )
    # spin() blocks until the node is shut down.
    node.spin()


if __name__ == "__main__":
    main()
@@ -0,0 +1,8 @@
1
+ from importlib.metadata import version
2
+ from .detector_yolo import YoloDetector, Yolo26Detector
3
+ from .localizer import HashLocalizer
4
+ from .tracker import SortTracker
5
+ from .warper import ClassicWarper, ClassicWarperWithExternalUpdate
6
+ from .fuser import FuserBase
7
+ from .state_estimator import FiniteDifferenceStateEstimator
8
+ __version__ = version("msight_vision")
msight_vision/base.py ADDED
@@ -0,0 +1,99 @@
1
+ from msight_base import DetectionResultBase, DetectedObjectBase, RoadUserPoint
2
+ import numpy as np
3
+ from typing import List, Dict
4
+
5
class DetectedObject2D(DetectedObjectBase):
    """Detected object for 2D images."""

    def __init__(self, box: list, class_id: int, score: float, pixel_bottom_center: List[float], obj_id: str = None, lat: float = None, lon: float = None, x: float = None, y: float = None):
        """
        Initialize the detected object.
        :param box: bounding box coordinates (x1, y1, x2, y2)
        :param class_id: class ID of the detected object
        :param score: confidence score of the detection
        :param pixel_bottom_center: pixel coordinates anchoring the object for localization
        :param obj_id: unique ID of the detected object (optional)
        :param lat: latitude of the detected object (optional)
        :param lon: longitude of the detected object (optional)
        :param x: x coordinate in the coordinate system of interest (e.g. UTM) of the detected object (optional)
        :param y: y coordinate in the coordinate system of interest (e.g. UTM) of the detected object (optional)
        """
        super().__init__()
        self.box = box
        self.class_id = class_id
        self.score = score
        self.pixel_bottom_center = pixel_bottom_center
        self.obj_id = obj_id
        self.lat = lat
        self.lon = lon
        self.x = x
        self.y = y

    def to_dict(self):
        """
        Convert the detected object to a dictionary.
        :return: dictionary representation of the detected object
        """
        return {
            "box": self.box,
            "class_id": self.class_id,
            "score": self.score,
            "obj_id": self.obj_id,
            "lat": self.lat,
            "lon": self.lon,
            "x": self.x,
            "y": self.y,
            "pixel_bottom_center": self.pixel_bottom_center,
        }

    @staticmethod
    def from_dict(data: dict):
        """
        Create a DetectedObject2D instance from a dictionary.
        :param data: dictionary representation of the detected object
        :return: DetectedObject2D instance
        """
        return DetectedObject2D(
            box=data["box"],
            class_id=data["class_id"],
            score=data["score"],
            obj_id=data.get("obj_id"),
            # Use .get() directly: the previous `data.get(...) or None` pattern
            # collapsed valid falsy values (e.g. lat/lon/x/y == 0.0) to None,
            # silently dropping coordinates on the equator/prime meridian or at
            # a local origin.
            lat=data.get("lat"),
            lon=data.get("lon"),
            x=data.get("x"),
            y=data.get("y"),
            pixel_bottom_center=data.get("pixel_bottom_center"),
        )

    def __repr__(self):
        return f"DetectedObject2D(box={self.box}, class_id={self.class_id}, score={self.score}, obj_id={self.obj_id}, lat={self.lat}, lon={self.lon}, x={self.x}, y={self.y})"
70
class DetectionResult2D(DetectionResultBase):
    """Detection result for 2D images: a list of detections plus frame metadata."""

    def __init__(self, object_list: List[DetectedObject2D], timestamp: int, sensor_type: str):
        """
        Initialize the detection result.
        :param object_list: list of detected objects in the frame
        :param timestamp: capture timestamp of the frame
        :param sensor_type: type of the sensor that produced the frame
        """
        super().__init__(object_list, timestamp, sensor_type)
79
+
80
class ImageDetector2DBase:
    """Detector base class: subclasses detect objects in a 2D image."""

    def detect(self, image: np.ndarray) -> DetectionResult2D:
        """
        Detect objects in the given image.
        :param image: input image
        :return: DetectionResult2D holding the detected objects
        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError("detect method not implemented")
88
+
89
class TrackerBase:
    """Tracker base class: subclasses associate detections across frames."""

    def __init__(self):
        pass

    def track(self, detection_list) -> Dict[str, RoadUserPoint]:
        """
        Track the detected objects in the image.

        The parameter was previously named ``list``, shadowing the builtin and
        contradicting the docstring; positional callers are unaffected by the
        rename.

        :param detection_list: detections for the current frame
        :return: mapping from object ID to RoadUserPoint with tracking information
        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError("track method not implemented")
@@ -0,0 +1,87 @@
1
+ from numpy import ndarray
2
+ from msight_vision.base import DetectionResult2D, DetectedObject2D
3
+ from .base import ImageDetector2DBase
4
+ from ultralytics import YOLO
5
+ from pathlib import Path
6
+
7
class YoloDetector(ImageDetector2DBase):
    """Ultralytics-YOLO based detector for 2D images."""

    def __init__(self, model_path: Path, device: str = "cpu", confthre: float = 0.25, nmsthre: float = 0.45, fp16: bool = False, class_agnostic_nms: bool = False):
        """
        Initialize the YOLO detector.
        :param model_path: path to the YOLO model weights
        :param device: device to run the model on (e.g., 'cpu', 'cuda')
        :param confthre: confidence threshold forwarded to the model
        :param nmsthre: NMS IoU threshold forwarded to the model
        :param fp16: whether to run inference in half precision
        :param class_agnostic_nms: whether NMS ignores class labels
        """
        super().__init__()
        self.model = YOLO(str(model_path))
        self.device = device
        self.confthre = confthre
        self.nmsthre = nmsthre
        self.fp16 = fp16
        self.class_agnostic_nms = class_agnostic_nms

    def convert_yolo_result_to_detection_result(self, yolo_output_results, timestamp, sensor_type):
        """
        Convert raw YOLO output into a DetectionResult2D.
        :param yolo_output_results: YOLO inference results (first entry is used)
        :param timestamp: timestamp of the image
        :param sensor_type: type of the sensor
        :return: DetectionResult2D instance
        """
        first = yolo_output_results[0]
        bboxes = first.boxes.xyxy.cpu().numpy()
        confs = first.boxes.conf.cpu().numpy()
        class_ids = first.boxes.cls.cpu().numpy()

        detected_objects = []
        for bbox, conf, cls in zip(bboxes, confs, class_ids):
            x1 = float(bbox[0])
            y1 = float(bbox[1])
            x2 = float(bbox[2])
            y2 = float(bbox[3])
            # NOTE(review): despite the field name, this is the geometric box
            # center ((x1+x2)/2, (y1+y2)/2), not the bottom-edge center —
            # confirm downstream localization expects this.
            anchor = [(x1 + x2) / 2, (y1 + y2) / 2]
            detected_objects.append(
                DetectedObject2D(
                    box=[x1, y1, x2, y2],
                    class_id=int(cls),
                    score=float(conf),
                    pixel_bottom_center=anchor,
                )
            )

        return DetectionResult2D(
            detected_objects,
            timestamp,
            sensor_type,
        )

    def detect(self, image: ndarray, timestamp, sensor_type) -> DetectionResult2D:
        """Run inference on the image and return a DetectionResult2D."""
        raw = self.model(
            image,
            device=self.device,
            conf=self.confthre,
            iou=self.nmsthre,
            half=self.fp16,
            verbose=False,
            agnostic_nms=self.class_agnostic_nms,
        )
        return self.convert_yolo_result_to_detection_result(raw, timestamp, sensor_type)
72
+
73
class Yolo26Detector(YoloDetector):
    """YOLOv2.6 detector for 2D images with an optional end-to-end inference mode."""

    def __init__(self, model_path: Path, device: str = "cpu", confthre: float = 0.25, nmsthre: float = 0.45, fp16: bool = False, class_agnostic_nms: bool = False, end2end: bool = False):
        """Same parameters as YoloDetector, plus `end2end` toggling end-to-end inference."""
        super().__init__(model_path, device, confthre, nmsthre, fp16, class_agnostic_nms)
        self.end2end = end2end

    def detect(self, image: ndarray, timestamp, sensor_type) -> DetectionResult2D:
        """Run inference (forwarding the end2end flag) and return a DetectionResult2D."""
        raw = self.model(
            image,
            device=self.device,
            conf=self.confthre,
            iou=self.nmsthre,
            half=self.fp16,
            verbose=False,
            agnostic_nms=self.class_agnostic_nms,
            end2end=self.end2end,
        )
        return self.convert_yolo_result_to_detection_result(raw, timestamp, sensor_type)
msight_vision/fuser.py ADDED
@@ -0,0 +1,325 @@
1
+ from typing import Dict, List, Tuple
2
+ from .base import DetectionResult2D
3
+ from geopy.distance import geodesic
4
+ from .utils import detection_to_roaduser_point
5
+ from msight_base import RoadUserPoint
6
+ from scipy.optimize import linear_sum_assignment
7
+ from shapely.geometry import Point, Polygon
8
+ import numpy as np
9
+
10
class FuserBase:
    """Abstract base for fusers that merge per-sensor detections into one list."""

    def fuse(self, results: Dict[str, DetectionResult2D]) -> List[RoadUserPoint]:
        """
        Fuses the data from different sources into a single output.
        :param results: mapping from sensor ID to that sensor's detection result
        :return: fused list of RoadUserPoint instances
        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError("FuserBase is an abstract class and cannot be instantiated directly.")
18
+
19
## This is a simple example of a fuser that combines the results from different cameras, which works at the roundabout of State and Ellsworth in Smart Intersection Project.
class StateEllsworthFuser(FuserBase):
    """
    Simple example fuser combining results from the four cameras at the
    roundabout of State and Ellsworth (Smart Intersection Project).

    Each camera keeps only the detections inside its own quadrant, selected
    by lat/lon threshold comparisons against two reference points.
    """

    def __init__(self):
        # Reference coordinates splitting the roundabout into per-camera quadrants.
        self.lat1 = 42.229379
        self.lon1 = -83.739003
        self.lat2 = 42.229444
        self.lon2 = -83.739013

    def fuse(self, detection_buffer: Dict[str, DetectionResult2D]) -> List[RoadUserPoint]:
        """
        Keep each camera's detections only inside that camera's quadrant.

        The four copy-pasted per-camera loops were collapsed into one
        data-driven loop; append order and behavior (including a KeyError
        when a camera key is missing) are unchanged.

        :param detection_buffer: mapping from camera ID to its detections
        :return: fused list of RoadUserPoint instances
        """
        # (camera_id, predicate selecting detections inside that camera's quadrant)
        quadrant_rules = [
            ('gs_State_Ellsworth_NW', lambda v: v.lat > self.lat1 and v.lon > self.lon1),
            ('gs_State_Ellsworth_NE', lambda v: v.lat > self.lat2 and v.lon < self.lon2),
            ('gs_State_Ellsworth_SE', lambda v: v.lat < self.lat1 and v.lon > self.lon1),
            ('gs_State_Ellsworth_SW', lambda v: v.lat < self.lat2 and v.lon < self.lon2),
        ]
        fused_vehicle_list = []
        for camera_id, in_quadrant in quadrant_rules:
            for v in detection_buffer[camera_id].object_list:
                if in_quadrant(v):
                    fused_vehicle_list.append(detection_to_roaduser_point(v, camera_id))
        return fused_vehicle_list
53
+
54
class HungarianFuser(FuserBase):
    """
    A fuser that matches detections from multiple sensors based on spatial proximity
    using Hungarian algorithm and fuses their locations using weighted averaging.
    """
    def __init__(self, coverage_zones: dict, sensor_locations: dict = None, distance_threshold: float = 5.0):
        """
        Initialize the HungarianFuser.
        :param coverage_zones: dict mapping sensor_id to a polygon defining the sensor's coverage zone.
                               Each polygon is a list of (lat, lon) tuples forming a closed polygon.
                               Example: [(lat1, lon1), (lat2, lon2), (lat3, lon3), ...]
        :param sensor_locations: dict mapping sensor_id to (lat, lon) tuple of the sensor's location.
                                 If provided, weights are computed as 1/distance_to_sensor^2.
                                 If not provided, bounding box area is used as weight.
        :param distance_threshold: maximum distance (in meters) to consider two detections as the same object.
        """
        self.coverage_zones = coverage_zones
        self.sensor_locations = sensor_locations
        self.distance_threshold = distance_threshold
        # Sensor order here determines processing order in fuse(); the first
        # sensor seeds the groups.
        self.sensor_list = list(coverage_zones.keys())

        # Pre-compute Shapely Polygon objects for efficient point-in-polygon checks
        self._coverage_polygons = {
            sensor_id: Polygon(polygon) if polygon else None
            for sensor_id, polygon in coverage_zones.items()
        }

    def _is_in_coverage(self, detected_object, sensor_id: str) -> bool:
        """
        Check if a detected object is within the sensor's coverage zone.
        :param detected_object: DetectedObject2D instance
        :param sensor_id: sensor identifier
        :return: True if in coverage zone, False otherwise
        """
        polygon = self._coverage_polygons.get(sensor_id)
        if polygon is None:
            return True  # No coverage filter defined, include all
        # NOTE(review): points are built as (lat, lon), matching the
        # (lat, lon) polygon vertices above; also, Polygon.contains()
        # excludes points exactly on the boundary — confirm both are intended.
        point = Point(detected_object.lat, detected_object.lon)
        return polygon.contains(point)

    def _compute_weight(self, detected_object, sensor_id: str) -> float:
        """
        Compute the weight for a detected object.
        If sensor location is available, use 1/distance_to_sensor^2.
        Otherwise, use the bounding box area.
        :param detected_object: DetectedObject2D instance
        :param sensor_id: sensor identifier
        :return: weight value
        """
        if self.sensor_locations is not None and sensor_id in self.sensor_locations:
            sensor_lat, sensor_lon = self.sensor_locations[sensor_id]
            # Compute geodesic distance to sensor in meters
            dist = geodesic((detected_object.lat, detected_object.lon), (sensor_lat, sensor_lon)).meters
            dist_sq = dist ** 2
            if dist_sq < 1e-10:
                dist_sq = 1e-10  # Avoid division by zero
            return 1.0 / dist_sq
        else:
            # Use bounding box area as weight
            # box is [x1, y1, x2, y2]
            box = detected_object.box
            width = box[2] - box[0]
            height = box[3] - box[1]
            area = width * height
            return area if area > 0 else 1.0

    def _compute_distance_to_group(self, detected_object, group: dict) -> float:
        """
        Compute the geodesic distance between a detected object and a group's weighted location.
        :param detected_object: DetectedObject2D instance
        :param group: group dict containing 'weighted_lat' and 'weighted_lon'
        :return: distance in meters
        """
        return geodesic(
            (detected_object.lat, detected_object.lon),
            (group['weighted_lat'], group['weighted_lon'])
        ).meters

    def _filter_detections_by_sensor(self, detection_buffer: Dict[str, DetectionResult2D]) -> Dict[str, List]:
        """
        Filter detections by coverage zone and organize by sensor.
        :param detection_buffer: dict mapping sensor_id to DetectionResult2D
        :return: dict mapping sensor_id to list of valid DetectedObject2D instances
        """
        detections_by_sensor = {}
        for sensor_id in self.sensor_list:
            if sensor_id not in detection_buffer:
                continue
            detection_result = detection_buffer[sensor_id]
            valid_detections = []
            for detected_object in detection_result.object_list:
                # Skip objects without valid lat/lon
                if detected_object.lat is None or detected_object.lon is None:
                    continue
                # Filter by coverage zone
                if self._is_in_coverage(detected_object, sensor_id):
                    valid_detections.append(detected_object)
            if valid_detections:
                detections_by_sensor[sensor_id] = valid_detections
        return detections_by_sensor

    def _create_group_from_detection(self, detected_object, sensor_id: str) -> dict:
        """
        Create a new group from a single detection.
        :param detected_object: DetectedObject2D instance
        :param sensor_id: sensor identifier
        :return: group dict
        """
        weight = self._compute_weight(detected_object, sensor_id)
        return {
            'weighted_lat': detected_object.lat,
            'weighted_lon': detected_object.lon,
            'total_weight': weight,
            'weighted_lat_sum': detected_object.lat * weight,
            'weighted_lon_sum': detected_object.lon * weight,
            'max_confidence': detected_object.score,
            'class_id_counts': {detected_object.class_id: 1},
            'sensor_data': {sensor_id: detected_object},
        }

    def _add_detection_to_group(self, group: dict, detected_object, sensor_id: str) -> None:
        """
        Add a detection to an existing group and update weighted location.
        :param group: group dict to update
        :param detected_object: DetectedObject2D instance
        :param sensor_id: sensor identifier
        """
        weight = self._compute_weight(detected_object, sensor_id)

        # Update weighted sums
        group['weighted_lat_sum'] += detected_object.lat * weight
        group['weighted_lon_sum'] += detected_object.lon * weight
        group['total_weight'] += weight

        # Update weighted location
        group['weighted_lat'] = group['weighted_lat_sum'] / group['total_weight']
        group['weighted_lon'] = group['weighted_lon_sum'] / group['total_weight']

        # Update confidence
        if detected_object.score > group['max_confidence']:
            group['max_confidence'] = detected_object.score

        # Update class_id counts
        class_id = detected_object.class_id
        group['class_id_counts'][class_id] = group['class_id_counts'].get(class_id, 0) + 1

        # Store sensor data (keep the detected object directly)
        # NOTE(review): a second detection from the same sensor would overwrite
        # the earlier one here; in practice Hungarian matching assigns at most
        # one detection per sensor to a group per fuse() pass.
        group['sensor_data'][sensor_id] = detected_object

    def _hungarian_match(self, groups: List[dict], detections: List, sensor_id: str) -> Tuple[List[dict], List]:
        """
        Use Hungarian algorithm to match detections to existing groups.
        :param groups: list of existing group dicts
        :param detections: list of DetectedObject2D instances from a single sensor
        :param sensor_id: sensor identifier
        :return: (updated groups, unmatched detections)
        """
        if not groups or not detections:
            return groups, detections

        n_groups = len(groups)
        n_detections = len(detections)

        # Build cost matrix (distance between each detection and each group)
        cost_matrix = np.zeros((n_detections, n_groups))
        for i, det in enumerate(detections):
            for j, group in enumerate(groups):
                dist = self._compute_distance_to_group(det, group)
                cost_matrix[i, j] = dist

        # Apply Hungarian algorithm
        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        matched_detection_indices = set()
        matched_group_indices = set()

        # Process matches
        for det_idx, group_idx in zip(row_ind, col_ind):
            dist = cost_matrix[det_idx, group_idx]
            # Assignments farther than the threshold are rejected even though
            # the solver paired them.
            if dist <= self.distance_threshold:
                # Valid match - add detection to group
                self._add_detection_to_group(groups[group_idx], detections[det_idx], sensor_id)
                matched_detection_indices.add(det_idx)
                matched_group_indices.add(group_idx)

        # Collect unmatched detections
        unmatched_detections = [
            detections[i] for i in range(n_detections) if i not in matched_detection_indices
        ]

        return groups, unmatched_detections

    def _group_to_road_user_point(self, group: dict) -> RoadUserPoint:
        """
        Convert a group to a RoadUserPoint.
        :param group: group dict
        :return: RoadUserPoint instance
        """
        # Determine the most common class_id
        most_common_class = max(group['class_id_counts'], key=group['class_id_counts'].get)

        # Convert sensor_data to dict format at the final stage
        sensor_data_dict = {
            sensor_id: det_obj.to_dict()
            for sensor_id, det_obj in group['sensor_data'].items()
        }

        # Create the fused RoadUserPoint
        # NOTE(review): latitude is stored as x and longitude as y — confirm
        # this matches the RoadUserPoint coordinate convention.
        road_user_point = RoadUserPoint(
            x=group['weighted_lat'],
            y=group['weighted_lon'],
            category=most_common_class,
            confidence=group['max_confidence'],
        )
        road_user_point.sensor_data = sensor_data_dict

        return road_user_point

    def fuse(self, detection_buffer: Dict[str, DetectionResult2D]) -> List[RoadUserPoint]:
        """
        Fuse detections from multiple sensors using Hungarian matching.

        Algorithm:
        1. For the first sensor, create a group for each detection (single object groups)
        2. For each subsequent sensor:
           a. Use Hungarian algorithm to match detections to existing groups
           b. For matched pairs within distance_threshold, add detection to group and update weighted location
           c. For unmatched detections, create new single-object groups
        3. Convert all groups to RoadUserPoints

        :param detection_buffer: dict mapping sensor_id to DetectionResult2D
        :return: list of fused RoadUserPoint instances
        """
        # Step 1: Filter and organize detections by sensor
        detections_by_sensor = self._filter_detections_by_sensor(detection_buffer)

        if not detections_by_sensor:
            return []

        # Get list of sensors that have detections (preserve order from sensor_list)
        active_sensors = [s for s in self.sensor_list if s in detections_by_sensor]

        if not active_sensors:
            return []

        # Step 2: Initialize groups with first sensor's detections
        first_sensor = active_sensors[0]
        groups = []
        for det in detections_by_sensor[first_sensor]:
            group = self._create_group_from_detection(det, first_sensor)
            groups.append(group)

        # Step 3: Process remaining sensors
        for sensor_id in active_sensors[1:]:
            detections = detections_by_sensor[sensor_id]

            # Hungarian matching against existing groups
            groups, unmatched_detections = self._hungarian_match(groups, detections, sensor_id)

            # Create new groups for unmatched detections
            for det in unmatched_detections:
                group = self._create_group_from_detection(det, sensor_id)
                groups.append(group)

        # Step 4: Convert groups to RoadUserPoints
        fused_results = []
        for group in groups:
            road_user_point = self._group_to_road_user_point(group)
            fused_results.append(road_user_point)

        return fused_results
325
+