msight-vision 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/__init__.py +0 -0
- cli/launch_2d_viewer.py +10 -0
- cli/launch_custom_fuser.py +23 -0
- cli/launch_finite_difference_state_estimator.py +22 -0
- cli/launch_road_user_list_viewer.py +23 -0
- cli/launch_sort_tracker.py +24 -0
- cli/launch_yolo_onestage_detection.py +22 -0
- msight_vision/__init__.py +8 -0
- msight_vision/base.py +99 -0
- msight_vision/detector_yolo.py +87 -0
- msight_vision/fuser.py +325 -0
- msight_vision/localizer.py +32 -0
- msight_vision/msight_core/__init__.py +6 -0
- msight_vision/msight_core/detection.py +103 -0
- msight_vision/msight_core/fusion.py +64 -0
- msight_vision/msight_core/state_estimation.py +38 -0
- msight_vision/msight_core/tracking.py +31 -0
- msight_vision/msight_core/viewer.py +55 -0
- msight_vision/msight_core/warper.py +98 -0
- msight_vision/state_estimator.py +121 -0
- msight_vision/tracker.py +525 -0
- msight_vision/utils/__init__.py +3 -0
- msight_vision/utils/data.py +80 -0
- msight_vision/utils/typing.py +18 -0
- msight_vision/utils/vis.py +17 -0
- msight_vision/warper.py +89 -0
- msight_vision-0.1.0.dist-info/METADATA +28 -0
- msight_vision-0.1.0.dist-info/RECORD +31 -0
- msight_vision-0.1.0.dist-info/WHEEL +5 -0
- msight_vision-0.1.0.dist-info/entry_points.txt +7 -0
- msight_vision-0.1.0.dist-info/top_level.txt +2 -0
msight_vision/localizer.py
@@ -0,0 +1,32 @@
class LocalizerBase:
    def __init__(self):
        pass

    def localize(self):
        raise NotImplementedError(
            "This method should be overridden by subclasses")


class HashLocalizer(LocalizerBase):
    """Hash-based localizer for 2D images.

    This localizer looks up pixel coordinates in precomputed latitude and
    longitude maps to find the geographic location of each detection.
    """

    def __init__(self, lat_map, lon_map):
        super().__init__()
        self.lat_map = lat_map
        self.lon_map = lon_map

    def localize(self, detection2d_result):
        # Look up each object's bottom-center pixel in the lat/lon maps.
        for obj in detection2d_result.object_list:
            bottom_center_x = int(obj.pixel_bottom_center[0])
            bottom_center_y = int(obj.pixel_bottom_center[1])
            obj.lat = self.lat_map[bottom_center_y, bottom_center_x]
            obj.lon = self.lon_map[bottom_center_y, bottom_center_x]
        return detection2d_result
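
For context, a minimal sketch of how HashLocalizer might be driven. The SimpleNamespace stand-ins for the detection result are hypothetical; the real objects come from the detector elsewhere in msight_vision.

import numpy as np
from types import SimpleNamespace
from msight_vision.localizer import HashLocalizer

# Hypothetical 2x2 lookup maps: pixel (row, col) -> latitude / longitude.
lat_map = np.array([[42.30, 42.30], [42.31, 42.31]])
lon_map = np.array([[-83.70, -83.69], [-83.70, -83.69]])

obj = SimpleNamespace(pixel_bottom_center=(1, 0), lat=None, lon=None)
result = SimpleNamespace(object_list=[obj])

HashLocalizer(lat_map, lon_map).localize(result)
print(obj.lat, obj.lon)  # 42.3 -83.69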
msight_vision/msight_core/__init__.py
@@ -0,0 +1,6 @@
from .detection import YoloOneStageDetectionNode
from .fusion import FuserNode
from .tracking import SortTrackerNode
from .viewer import RoadUserListViewerNode, DetectionResults2DViewerNode
from .warper import WarperMatrixUpdaterNode
from .state_estimation import FiniteDifferenceStateEstimatorNode
msight_vision/msight_core/detection.py
@@ -0,0 +1,103 @@
from msight_core.nodes import DataProcessingNode, NodeConfig
from msight_core.data import ImageData, DetectionResultsData
import yaml
from pathlib import Path
import numpy as np
from .. import YoloDetector, HashLocalizer, ClassicWarperWithExternalUpdate
import torch
import time
from msight_core.utils import get_redis_client


device = "cuda" if torch.cuda.is_available() else "cpu"


def load_locmaps(loc_maps_path):
    """
    Load localization maps from the specified paths.

    :param loc_maps_path: mapping of sensor name to localization-map path, as
        given in the config file
    :return: mapping of sensor name to loaded localization maps
    """
    return {key: np.load(item) for key, item in loc_maps_path.items()}


class YoloOneStageDetectionNode(DataProcessingNode):
    default_configs = NodeConfig(
        publish_topic_data_type=DetectionResultsData
    )

    def __init__(self, configs, det_configs_path):
        super().__init__(configs)
        self.det_config_path = Path(det_configs_path)
        with open(self.det_config_path, "r") as f:
            self.det_config = yaml.safe_load(f)
        self.no_warp = self.det_config["warper_config"]["no_warp"]
        self.model_path = self.det_config["model_config"]["ckpt_path"]
        self.confthre = self.det_config["model_config"]["confthre"]
        self.nmsthre = self.det_config["model_config"]["nmsthre"]
        self.class_agnostic_nms = self.det_config["model_config"]["class_agnostic_nms"]
        self.logger.info(
            f"Initializing YoloOneStageDetectionNode with model path: {self.model_path}, "
            f"no_warp: {self.no_warp}, confthre: {self.confthre}, nmsthre: {self.nmsthre}, "
            f"class_agnostic_nms: {self.class_agnostic_nms}")
        self.detector = YoloDetector(model_path=Path(self.model_path), device=device,
                                     confthre=self.confthre, nmsthre=self.nmsthre,
                                     fp16=False, class_agnostic_nms=self.class_agnostic_nms)
        loc_maps = load_locmaps(self.det_config["loc_maps"])
        # Each loaded map provides per-pixel latitude ('x_map') and longitude ('y_map') lookups.
        self.localizers = {key: HashLocalizer(item['x_map'], item['y_map'])
                           for key, item in loc_maps.items()}
        if not self.no_warp:
            self.warper = ClassicWarperWithExternalUpdate()
            self.warper_matrix_redis_prefix = self.det_config["warper_config"]["redis_prefix"]
        else:
            self.warper_matrix_redis_prefix = None
        self.include_sensor_data_in_result = self.det_config["det_config"].get(
            "include_sensor_data_in_result", False)
        self.sensor_type = self.det_config["det_config"].get("sensor_type", "fisheye")

    def get_warp_matrix_from_redis(self, sensor_name):
        redis_client = get_redis_client()
        key = self.warper_matrix_redis_prefix + f":{sensor_name}"
        warp_matrix_str = redis_client.get(key)

        if warp_matrix_str is None:
            self.logger.warning(f"No warp matrix found in Redis for sensor: {sensor_name}")
            return None

        # Decode bytes to string if needed
        if isinstance(warp_matrix_str, bytes):
            warp_matrix_str = warp_matrix_str.decode()

        # The matrix was stored via np.array2string, so evaluate the bracketed
        # literal back into a numpy array and restore the 3x3 homography shape.
        warp_matrix = np.array(eval(warp_matrix_str))
        warp_matrix = warp_matrix.reshape((3, 3))

        return warp_matrix

    def process(self, data: ImageData):
        self.logger.info(f"Processing image data from sensor: {data.sensor_name}, frame: {data.frame_id}")
        start = time.time()
        image = data.to_ndarray()
        sensor_name = data.sensor_name
        timestamp = data.capture_timestamp
        if not self.no_warp:
            # Rectify the image with the latest homography published by the warper node.
            warping_matrix = self.get_warp_matrix_from_redis(sensor_name)
            image = self.warper.warp(image, warping_matrix)
        result = self.detector.detect(image, timestamp, self.sensor_type)
        localizer = self.localizers[sensor_name]
        localizer.localize(result)
        self.logger.info(f"Detection completed in {time.time() - start:.2f} seconds for sensor: {sensor_name}")

        def is_number(val):
            return isinstance(val, (int, float, np.number)) and np.isfinite(val)

        # Drop detections that could not be localized to finite coordinates.
        result.object_list = [obj for obj in result.object_list
                              if is_number(obj.lat) and is_number(obj.lon)]
        raw_sensor_data = data if self.include_sensor_data_in_result else None
        detection_result_data = DetectionResultsData(
            result,
            sensor_frame_id=data.frame_id,
            capture_timestamp=data.capture_timestamp,
            creation_timestamp=time.time(),
            sensor_name=sensor_name,
            raw_sensor_data=raw_sensor_data)
        return detection_result_data
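
The keys read above imply a detection config of roughly the following shape once parsed by yaml.safe_load. This is a sketch; every path and value is a placeholder, and each loc_maps entry must point at an archive containing 'x_map' and 'y_map' arrays.

det_config = {
    "model_config": {
        "ckpt_path": "models/yolo_ckpt.pth",
        "confthre": 0.3,
        "nmsthre": 0.45,
        "class_agnostic_nms": True,
    },
    "warper_config": {
        "no_warp": False,
        "redis_prefix": "warp_matrix",
    },
    "det_config": {
        "include_sensor_data_in_result": False,  # optional, default False
        "sensor_type": "fisheye",                # optional, default "fisheye"
    },
    "loc_maps": {
        "camera_1": "maps/camera_1_locmap.npz",
    },
}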
msight_vision/msight_core/fusion.py
@@ -0,0 +1,64 @@
from msight_core.nodes import DataProcessingNode, NodeConfig
from msight_core.data import RoadUserListData, DetectionResultsData
from pathlib import Path
import yaml
import importlib.util
import sys
import copy


def load_class_from_file(file_path: str, class_path: str):
    module_name = Path(file_path).stem
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Cannot load module from {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)

    # Resolve the class by walking the dotted class_path, skipping the module name.
    components = class_path.split(".")
    cls = module
    for comp in components[1:]:
        cls = getattr(cls, comp)
    return cls


class FuserNode(DataProcessingNode):
    default_configs = NodeConfig(
        publish_topic_data_type=RoadUserListData
    )

    def __init__(self, configs, fusion_configs_path):
        super().__init__(configs)
        self.config_file_path = Path(fusion_configs_path)
        with open(self.config_file_path, "r") as f:
            self.fusion_configs = yaml.safe_load(f)
        self.fuser_class_path = self.fusion_configs["fuser_config"]["class_path"]
        self.fuser_file_path = self.fusion_configs["fuser_config"]["file_path"]
        self.sensor_list = self.fusion_configs["fuser_config"]["sensor_list"]
        FuserClass = load_class_from_file(self.fuser_file_path, self.fuser_class_path)
        self.fuser = FuserClass()
        self.buffer = {sensor: None for sensor in self.sensor_list}
        assert configs.sensor_name is not None, "sensor_name must be provided in configs for the fusion node."
        self.sensor_name = self.configs.sensor_name

    def process(self, data: DetectionResultsData):
        sensor_name = data.sensor_name
        self.logger.info(f"Processing detection results from sensor: {sensor_name}")
        if sensor_name not in self.sensor_list:
            raise ValueError(f"Sensor {sensor_name} not in configured sensor list: {self.sensor_list}")
        self.buffer[sensor_name] = data.detection_result
        buffer = copy.copy(self.buffer)

        # Fuse only once every configured sensor has contributed a result.
        for detection_result in buffer.values():
            if detection_result is None:
                return None
        self.logger.info("All data received. Fusing detection results.")
        fused_result = self.fuser.fuse(buffer)
        road_user_list_data = RoadUserListData(
            road_user_list=fused_result,
            capture_timestamp=data.capture_timestamp,
            sensor_name=self.sensor_name
        )
        self.buffer = {sensor: None for sensor in self.sensor_list}  # Reset the buffer after fusion
        return road_user_list_data
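
The fuser class is loaded dynamically from the file and dotted path named in the fusion config, so a deployment can swap fusion strategies without touching this node. A minimal sketch of such a class follows; the file name, class name, and the assumption that fuse receives the per-sensor buffer and returns a road-user list are all illustrative, inferred from the single fuse(buffer) call above.

# my_fuser.py, referenced from the fusion config as
#   fuser_config: {file_path: ./my_fuser.py, class_path: my_fuser.ConcatFuser, sensor_list: [...]}
class ConcatFuser:
    def fuse(self, buffer):
        # buffer maps sensor name -> detection result; naively concatenate
        # every sensor's localized objects into a single list.
        fused = []
        for detection_result in buffer.values():
            fused.extend(detection_result.object_list)
        return fused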
msight_vision/msight_core/state_estimation.py
@@ -0,0 +1,38 @@
from msight_core.nodes import DataProcessingNode, NodeConfig
from msight_core.data import RoadUserListData
from .. import FiniteDifferenceStateEstimator
from pathlib import Path
import yaml
import time


class FiniteDifferenceStateEstimatorNode(DataProcessingNode):
    default_configs = NodeConfig(
        publish_topic_data_type=RoadUserListData
    )

    def __init__(self, configs, state_estimation_configs_path):
        super().__init__(configs)
        self.config_file_path = Path(state_estimation_configs_path)
        with open(self.config_file_path, "r") as f:
            self.state_estimation_configs = yaml.safe_load(f)
        estimator_configs = self.state_estimation_configs["state_estimator_config"]
        self.frame_rate = estimator_configs.get("frame_rate", 5)
        self.frame_interval = estimator_configs.get("frame_interval", 1)
        self.dist_threshold = estimator_configs.get("dist_threshold", 4)
        self.state_estimator = FiniteDifferenceStateEstimator(
            frame_rate=self.frame_rate,
            frame_interval=self.frame_interval,
            dist_threshold=self.dist_threshold
        )

    def process(self, data: RoadUserListData) -> RoadUserListData:
        self.logger.info(f"Processing road user list data from sensor: {data.sensor_name}")
        start = time.time()
        data.road_user_list = self.state_estimator.estimate(data.road_user_list)
        self.logger.info(f"State estimation completed in {time.time() - start:.2f} seconds for sensor: {data.sensor_name}")
        return data
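
The node reads three optional keys, each with a default. Parsed with yaml.safe_load, the expected config is roughly the following sketch (the values shown are the code's own defaults, for illustration):

state_estimation_configs = {
    "state_estimator_config": {
        "frame_rate": 5,      # Hz of the incoming road-user stream
        "frame_interval": 1,  # frames between the two differenced points
        "dist_threshold": 4,  # meters; above this, the previous heading is kept
    },
}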
msight_vision/msight_core/tracking.py
@@ -0,0 +1,31 @@
from msight_core.nodes import DataProcessingNode, NodeConfig
from msight_core.data import RoadUserListData
from .. import SortTracker
from pathlib import Path
import yaml
import time


class SortTrackerNode(DataProcessingNode):
    default_configs = NodeConfig(
        publish_topic_data_type=RoadUserListData
    )

    def __init__(self, configs, tracking_configs_path):
        super().__init__(configs)
        self.config_file_path = Path(tracking_configs_path)
        with open(self.config_file_path, "r") as f:
            self.tracking_configs = yaml.safe_load(f)
        self.tracker = SortTracker()

    def process(self, data: RoadUserListData) -> RoadUserListData:
        self.logger.info(f"Processing road user list data from sensor: {data.sensor_name}")
        start = time.time()
        data.road_user_list = self.tracker.track(data.road_user_list)
        self.logger.info(f"Tracking completed in {time.time() - start:.2f} seconds for sensor: {data.sensor_name}")
        return data
msight_vision/msight_core/viewer.py
@@ -0,0 +1,55 @@
from msight_core.nodes import SinkNode
from msight_core.data import RoadUserListData, DetectionResultsData
from msight_base.visualizer import Visualizer
from msight_base import Frame
import cv2


class RoadUserListViewerNode(SinkNode):
    def __init__(self, configs, basemap_path, with_traj=True, show_heading=False):
        super().__init__(configs)
        self.basemap_path = basemap_path
        self.visualizer = Visualizer(basemap_path)
        self.with_traj = with_traj
        self.show_heading = show_heading
        self.step = 0

    def on_message(self, data: RoadUserListData):
        self.logger.info(f"Received road user list data from sensor: {data.sensor_name}")
        road_user_list = data.road_user_list
        if not self.with_traj:
            # Render instantaneous points only, without trajectory history.
            vis_image = self.visualizer.render_roaduser_points(road_user_list)
            cv2.imshow(self.name, vis_image)
            cv2.waitKey(1)
            return
        result_frame = Frame(self.step)
        for obj in road_user_list:
            result_frame.add_object(obj)
        vis_img = self.visualizer.render(result_frame, with_traj=self.with_traj, show_heading=self.show_heading)
        cv2.imshow(self.name, vis_img)
        cv2.waitKey(1)
        self.step += 1


class DetectionResults2DViewerNode(SinkNode):
    def __init__(self, configs):
        super().__init__(configs)

    def on_message(self, data: DetectionResultsData):
        self.logger.info(f"Received detection results from sensor: {data.sensor_name}")
        # Requires the detection node to publish with include_sensor_data_in_result enabled.
        raw_image_data = data.raw_sensor_data
        decoded_image = raw_image_data.to_ndarray()
        detection_result = data.detection_result
        for obj in detection_result.object_list:
            # Draw the bounding box and its bottom-center anchor point.
            x1, y1, x2, y2 = map(int, obj.box)
            px, py = map(int, obj.pixel_bottom_center)
            cv2.circle(decoded_image, (px, py), 5, (0, 0, 255), -1)
            cv2.rectangle(decoded_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.imshow(self.name, decoded_image)
        cv2.waitKey(1)
msight_vision/msight_core/warper.py
@@ -0,0 +1,98 @@
from msight_core.nodes import SinkNode, NodeConfig
from msight_core.data import ImageData
import cv2
import numpy as np
import time
from msight_core.utils import get_redis_client
from pathlib import Path
import yaml
import threading


class WarperMatrixUpdaterNode(SinkNode):
    def __init__(self, configs, warper_configs_path):
        '''
        The constructor for the WarperMatrixUpdaterNode class.

        The warper config file provides:
        - std_imgs: per-sensor paths to the standard (reference) images used for warping
        - update_interval: interval in steps between updates
        - time_threshold: time threshold in seconds to trigger an update
        - redis_prefix: Redis prefix for storing the warp matrix
        '''
        super().__init__(configs)
        self.warper_config_path = Path(warper_configs_path)
        with open(self.warper_config_path, "r") as f:
            self.warper_config = yaml.safe_load(f)
        self.update_interval = self.warper_config["warper_config"]["update_interval"]
        self.time_threshold = self.warper_config["warper_config"]["time_threshold"]
        self.redis_prefix = self.warper_config["warper_config"]["redis_prefix"]
        standard_images_paths = self.warper_config["warper_config"]["std_imgs"]
        self.standard_images = {sensor_name: cv2.imread(img_path)
                                for sensor_name, img_path in standard_images_paths.items()}
        self.steps = {sensor_name: 0 for sensor_name in standard_images_paths}
        self.last_update_times = {sensor_name: time.time() for sensor_name in standard_images_paths}

    def get_warp_matrix_between_two_images(self, im1, im2):
        # Convert images to grayscale
        im1_gray = cv2.cvtColor(im1, cv2.COLOR_BGR2GRAY)
        im2_gray = cv2.cvtColor(im2, cv2.COLOR_BGR2GRAY)

        # Define the motion model
        warp_mode = cv2.MOTION_HOMOGRAPHY

        # Initialize the warp matrix to identity: 3x3 for a homography,
        # 2x3 for the affine motion models.
        if warp_mode == cv2.MOTION_HOMOGRAPHY:
            warp_matrix = np.eye(3, 3, dtype=np.float32)
        else:
            warp_matrix = np.eye(2, 3, dtype=np.float32)

        # Maximum number of iterations, and the minimum increment in the
        # correlation coefficient between two iterations.
        number_of_iterations = 500
        termination_eps = 1e-10
        criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
                    number_of_iterations, termination_eps)

        # Run the ECC algorithm; the estimated transform is stored in warp_matrix.
        (cc, warp_matrix) = cv2.findTransformECC(
            im1_gray, im2_gray, warp_matrix, warp_mode, criteria)
        return warp_matrix

    def update_warp_matrix(self, image, sensor_name):
        standard_img = self.standard_images[sensor_name]
        return self.get_warp_matrix_between_two_images(standard_img, image)

    def update_warp_matrix_in_redis(self, warp_matrix, sensor_name):
        redis_client = get_redis_client()
        # Convert the warp matrix to a string and store it in Redis
        warp_matrix_str = np.array2string(warp_matrix, separator=',')
        redis_client.set(self.redis_prefix + f":{sensor_name}", warp_matrix_str)

    def on_message(self, data: ImageData):
        sensor_name = data.sensor_name
        self.logger.debug(f"Received one image for {sensor_name}")
        if (self.steps[sensor_name] % self.update_interval == 0
                or time.time() - self.last_update_times[sensor_name] > self.time_threshold):
            # Run the (slow) ECC estimation off the message thread.
            def _update():
                self.logger.info(f"Updating warp matrix for {sensor_name}")
                start = time.time()
                image = cv2.imdecode(data.encoded_image, cv2.IMREAD_COLOR)
                warp_matrix = self.update_warp_matrix(image, sensor_name)
                self.update_warp_matrix_in_redis(warp_matrix, sensor_name)
                self.logger.info(f"Updated warp matrix for sensor: {sensor_name} in {time.time() - start:.2f} seconds")

            x = threading.Thread(target=_update, daemon=True)
            x.start()
        self.steps[sensor_name] = (self.steps[sensor_name] + 1) % self.update_interval
        self.last_update_times[sensor_name] = time.time()
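
The warp matrix round-trips through Redis as a bracketed string: update_warp_matrix_in_redis above writes it with np.array2string, and get_warp_matrix_from_redis in detection.py evaluates it back. A standalone sketch of that encode/decode pair, verifiable without Redis:

import numpy as np

warp_matrix = np.eye(3, dtype=np.float32)
# Writer side (update_warp_matrix_in_redis): matrix -> bracketed string.
encoded = np.array2string(warp_matrix, separator=',')
# Reader side (get_warp_matrix_from_redis): string -> 3x3 matrix.
decoded = np.array(eval(encoded)).reshape((3, 3))
assert np.allclose(decoded, warp_matrix)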
msight_vision/state_estimator.py
@@ -0,0 +1,121 @@
from typing import List
from msight_base import RoadUserPoint, TrajectoryManager
from geopy.distance import geodesic
import numpy as np


class StateEstimatorBase:
    def estimate(self, road_user_point_list: List[RoadUserPoint]) -> List[RoadUserPoint]:
        """
        Estimate the state of road users based on the provided list of RoadUserPoint instances.

        :param road_user_point_list: List of RoadUserPoint instances to estimate the state from.
        :return: Estimated state of road users.
        """
        raise NotImplementedError("StateEstimatorBase is an abstract class and cannot be instantiated directly.")


class FiniteDifferenceStateEstimator(StateEstimatorBase):
    def __init__(self, frame_rate=5, frame_interval=1, dist_threshold=2):
        """
        Initialize the finite difference state estimator.

        :param frame_rate: Frame rate of the video stream.
        :param frame_interval: Number of frames between the two points used for the
            difference; two neighboring frames have interval 0.
        :param dist_threshold: Maximum displacement in meters for which the heading
            is re-estimated; beyond it, the anchor's heading is kept.
        """
        self.frame_rate = frame_rate
        self.frame_interval = frame_interval
        self.trajectory_manager = TrajectoryManager(max_frames=100)
        self.dist_threshold = dist_threshold

    def get_anchor_point(self, obj):
        traj = obj.traj
        if len(traj.steps) <= 1:
            return None  # a single-point trajectory has no earlier anchor
        if len(traj.steps) < self.frame_interval + 2:
            anchor_index = 0
        else:
            anchor_index = len(traj.steps) - self.frame_interval - 1
        anchor_step = traj.steps[anchor_index]
        anchor = traj.step_to_object_map[anchor_step]
        return anchor

    def calc_xy_difference(self, obj: RoadUserPoint, anchor: RoadUserPoint, scale="latlon"):
        """
        Calculate the signed displacement between the object and the anchor point
        in the specified scale.

        :param obj: RoadUserPoint instance representing the object.
        :param anchor: RoadUserPoint instance representing the anchor point.
        :param scale: Scale of the coordinates, either "latlon", "utm" or "meters"
            ("meters" and "utm" are equivalent).
        :return: (dx, dy) displacement from the anchor to the object in meters.
        """
        if scale == "latlon":
            lat1, lon1 = obj.x, obj.y
            lat2, lon2 = anchor.x, anchor.y
            # Geodesic distance along each axis, with the sign preserved.
            dx = geodesic((lat1, lon1), (lat2, lon1)).meters
            if lat1 < lat2:
                dx = -dx
            dy = geodesic((lat1, lon1), (lat1, lon2)).meters
            if lon1 < lon2:
                dy = -dy
        elif scale in ["utm", "meters"]:
            # Coordinates are already metric, so take the plain difference.
            dx = obj.x - anchor.x
            dy = obj.y - anchor.y
        else:
            raise ValueError("Invalid scale. Use 'latlon', 'utm' or 'meters'.")
        return dx, dy

    def calc_heading(self, dx: float, dy: float, spatial_distance, fallback) -> float:
        """
        Calculate the heading of the object from the displacement (dx, dy).

        :param dx: Difference in x coordinates (northward, in meters).
        :param dy: Difference in y coordinates (eastward, in meters).
        :param spatial_distance: Displacement magnitude between the two points.
        :param fallback: Heading to keep when the displacement exceeds dist_threshold.
        :return: Heading of the object in degrees.
        """
        # Compass heading, clockwise in degrees: north 0, east 90, north-east 45.
        if spatial_distance > self.dist_threshold:
            return fallback
        heading = np.arctan2(dy, dx) * 180 / np.pi
        return heading

    def calc_speed(self, obj: RoadUserPoint, anchor: RoadUserPoint, spatial_distance) -> float:
        """
        Calculate the speed of the object from the displacement and the elapsed frames.

        :param obj: RoadUserPoint instance representing the object.
        :param anchor: RoadUserPoint instance representing the anchor point.
        :param spatial_distance: Displacement magnitude between the two points in meters.
        :return: Speed of the object in meters per second.
        """
        step_now = obj.frame_step
        step_anchor = anchor.frame_step
        time_difference = (step_now - step_anchor) / self.frame_rate  # in seconds
        speed = (spatial_distance / time_difference) if time_difference > 0 else 0.0
        return speed

    def estimate(self, road_user_point_list: List[RoadUserPoint], scale="latlon") -> List[RoadUserPoint]:
        """
        Estimate the state of road users based on the provided list of RoadUserPoint instances.

        :param road_user_point_list: List of RoadUserPoint instances to estimate the state from.
        :param scale: Scale of the coordinates, either "latlon", "utm" or "meters"
            ("meters" and "utm" are equivalent).
        :return: Estimated state of road users.
        """
        self.trajectory_manager.add_list_as_new_frame(road_user_point_list)
        for obj in road_user_point_list:
            anchor = self.get_anchor_point(obj)
            if anchor is None:
                continue
            dx, dy = self.calc_xy_difference(obj, anchor, scale)
            spatial_distance = (dx**2 + dy**2)**0.5
            obj.heading = self.calc_heading(dx, dy, spatial_distance, anchor.heading)
            obj.speed = self.calc_speed(obj, anchor, spatial_distance)

        return road_user_point_list
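
A quick numeric check of the finite-difference arithmetic, reimplemented standalone since RoadUserPoint and TrajectoryManager come from msight_base (the displacement values are illustrative):

import numpy as np

frame_rate = 5  # Hz
dx, dy = 1.0, 1.0  # anchor-to-object displacement: 1 m north, 1 m east
spatial_distance = (dx**2 + dy**2) ** 0.5                 # ~1.414 m
heading = np.arctan2(dy, dx) * 180 / np.pi                # 45.0 degrees (north-east)
frames_elapsed = 2
speed = spatial_distance / (frames_elapsed / frame_rate)  # ~3.54 m/s over 0.4 s
print(heading, speed)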