msight-vision 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,32 @@
+ class LocalizerBase:
+     def __init__(self):
+         pass
+
+     def localize(self, detection2d_result):
+         raise NotImplementedError(
+             "This method should be overridden by subclasses")
+
+
+ class HashLocalizer(LocalizerBase):
+     """Hash-based localizer for 2D detections.
+
+     Looks up each detection's bottom-center pixel coordinates in
+     precomputed latitude/longitude maps to recover its geographic location.
+     """
+
+     def __init__(self, lat_map, lon_map):
+         super().__init__()
+         self.lat_map = lat_map
+         self.lon_map = lon_map
+
+     def localize(self, detection2d_result):
+         # Look up each object's geolocation from the per-pixel maps.
+         for obj in detection2d_result.object_list:
+             bottom_center_x = int(obj.pixel_bottom_center[0])
+             bottom_center_y = int(obj.pixel_bottom_center[1])
+             obj.lat = self.lat_map[bottom_center_y, bottom_center_x]
+             obj.lon = self.lon_map[bottom_center_y, bottom_center_x]
+         return detection2d_result
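
For reference, a minimal sketch of how HashLocalizer might be exercised with synthetic maps. The result and object stand-ins here are assumptions inferred from the attribute names used above, not the package's real data types:

import numpy as np

class _Obj:  # hypothetical stand-in for a detection object
    def __init__(self, pixel_bottom_center):
        self.pixel_bottom_center = pixel_bottom_center
        self.lat = None
        self.lon = None

class _Result:  # hypothetical stand-in for a 2D detection result
    def __init__(self, object_list):
        self.object_list = object_list

lat_map = np.full((1080, 1920), 42.30)   # one latitude value per pixel
lon_map = np.full((1080, 1920), -83.70)  # one longitude value per pixel
localizer = HashLocalizer(lat_map, lon_map)
result = localizer.localize(_Result([_Obj((960.0, 540.0))]))
print(result.object_list[0].lat, result.object_list[0].lon)
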
@@ -0,0 +1,6 @@
+ from .detection import YoloOneStageDetectionNode
+ from .fusion import FuserNode
+ from .tracking import SortTrackerNode
+ from .viewer import RoadUserListViewerNode, DetectionResults2DViewerNode
+ from .warper import WarperMatrixUpdaterNode
+ from .state_estimation import FiniteDifferenceStateEstimatorNode
@@ -0,0 +1,103 @@
+ from msight_core.nodes import DataProcessingNode, NodeConfig
+ from msight_core.data import ImageData, DetectionResultsData
+ import yaml
+ from pathlib import Path
+ import numpy as np
+ from .. import YoloDetector, HashLocalizer, ClassicWarperWithExternalUpdate
+ import torch
+ import time
+ import ast
+ from msight_core.utils import get_redis_client
+
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def load_locmaps(loc_maps_path):
+     """
+     Load localization maps from the paths given in the config file.
+     :param loc_maps_path: mapping from sensor name to the path of its localization maps
+     :return: mapping from sensor name to the loaded localization maps
+     """
+     return {key: np.load(item) for key, item in loc_maps_path.items()}
+
+ class YoloOneStageDetectionNode(DataProcessingNode):
+     default_configs = NodeConfig(
+         publish_topic_data_type=DetectionResultsData
+     )
+
+     def __init__(self, configs, det_configs_path):
+         super().__init__(configs)
+         self.det_config_path = Path(det_configs_path)
+         with open(self.det_config_path, "r") as f:
+             self.det_config = yaml.safe_load(f)
+         self.no_warp = self.det_config["warper_config"]["no_warp"]
+         self.model_path = self.det_config["model_config"]["ckpt_path"]
+         self.confthre = self.det_config["model_config"]["confthre"]
+         self.nmsthre = self.det_config["model_config"]["nmsthre"]
+         self.class_agnostic_nms = self.det_config["model_config"]["class_agnostic_nms"]
+         self.logger.info(f"Initializing YoloOneStageDetectionNode with model path: {self.model_path}, no_warp: {self.no_warp}, confthre: {self.confthre}, nmsthre: {self.nmsthre}, class_agnostic_nms: {self.class_agnostic_nms}")
+         self.detector = YoloDetector(model_path=Path(self.model_path), device=device, confthre=self.confthre, nmsthre=self.nmsthre, fp16=False, class_agnostic_nms=self.class_agnostic_nms)
+         loc_maps = load_locmaps(self.det_config["loc_maps"])
+         # The per-sensor "x_map"/"y_map" arrays serve as the latitude/longitude lookup maps.
+         self.localizers = {key: HashLocalizer(item['x_map'], item['y_map']) for key, item in loc_maps.items()}
+         if not self.no_warp:
+             self.warper = ClassicWarperWithExternalUpdate()
+             self.warper_matrix_redis_prefix = self.det_config["warper_config"]["redis_prefix"]
+         else:
+             self.warper_matrix_redis_prefix = None
+         self.include_sensor_data_in_result = self.det_config["det_config"].get("include_sensor_data_in_result", False)
+         self.sensor_type = self.det_config["det_config"].get("sensor_type", "fisheye")
+
+     def get_warp_matrix_from_redis(self, sensor_name):
+         redis_client = get_redis_client()
+         key = self.warper_matrix_redis_prefix + f":{sensor_name}"
+         warp_matrix_str = redis_client.get(key)
+
+         if warp_matrix_str is None:
+             self.logger.warning(f"No warp matrix found in Redis for sensor: {sensor_name}")
+             return None
+
+         # Decode bytes to string if needed
+         if isinstance(warp_matrix_str, bytes):
+             warp_matrix_str = warp_matrix_str.decode()
+
+         # Parse the bracketed string written by np.array2string back into an array.
+         # ast.literal_eval is used instead of eval so arbitrary Redis content cannot execute code.
+         warp_matrix = np.array(ast.literal_eval(warp_matrix_str))
+
+         # Reshape to 3x3; adjust if your warp matrix has a different shape.
+         warp_matrix = warp_matrix.reshape((3, 3))
+
+         return warp_matrix
+
+     def process(self, data: ImageData):
+         self.logger.info(f"Processing image data from sensor: {data.sensor_name}, frame: {data.frame_id}")
+         start = time.time()
+         image = data.to_ndarray()
+         sensor_name = data.sensor_name
+         timestamp = data.capture_timestamp
+         if not self.no_warp:
+             warping_matrix = self.get_warp_matrix_from_redis(sensor_name)
+             image = self.warper.warp(image, warping_matrix)
+         result = self.detector.detect(image, timestamp, self.sensor_type)
+         localizer = self.localizers[sensor_name]
+         localizer.localize(result)
+         self.logger.info(f"Detection completed in {time.time() - start:.2f} seconds for sensor: {sensor_name}")
+
+         # Drop detections whose localization produced a non-finite latitude or longitude.
+         def is_number(val):
+             return isinstance(val, (int, float, np.number)) and np.isfinite(val)
+         result.object_list = [obj for obj in result.object_list if is_number(obj.lat) and is_number(obj.lon)]
+
+         raw_sensor_data = data if self.include_sensor_data_in_result else None
+         detection_result_data = DetectionResultsData(result, sensor_frame_id=data.frame_id, capture_timestamp=data.capture_timestamp, creation_timestamp=time.time(), sensor_name=sensor_name, raw_sensor_data=raw_sensor_data)
+         return detection_result_data
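
The keys read in __init__ imply a detection YAML of roughly the following shape. This is a sketch; every path and value below is an illustrative assumption, not a shipped default:

model_config:
  ckpt_path: /path/to/yolo.ckpt      # illustrative checkpoint path
  confthre: 0.3
  nmsthre: 0.45
  class_agnostic_nms: false
warper_config:
  no_warp: false
  redis_prefix: warp_matrix
loc_maps:
  cam_0: /path/to/cam_0_locmaps.npz  # must contain x_map and y_map arrays
det_config:
  include_sensor_data_in_result: false
  sensor_type: fisheye
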
@@ -0,0 +1,64 @@
+ from msight_core.nodes import DataProcessingNode, NodeConfig
+ from msight_core.data import RoadUserListData, DetectionResultsData
+ from pathlib import Path
+ import yaml
+ import importlib.util
+ import sys
+ import copy
+
+ def load_class_from_file(file_path: str, class_path: str):
+     module_name = Path(file_path).stem
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load module from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)
+
+     # Resolve the class from class_path, e.g. "my_module.MyFuser"
+     components = class_path.split(".")
+     cls = module
+     for comp in components[1:]:  # skip the module name itself
+         cls = getattr(cls, comp)
+     return cls
+
+ class FuserNode(DataProcessingNode):
+     default_configs = NodeConfig(
+         publish_topic_data_type=RoadUserListData
+     )
+
+     def __init__(self, configs, fusion_configs_path):
+         super().__init__(configs)
+         self.config_file_path = Path(fusion_configs_path)
+         with open(self.config_file_path, "r") as f:
+             self.fusion_configs = yaml.safe_load(f)
+         self.fuser_class_path = self.fusion_configs["fuser_config"]["class_path"]
+         self.fuser_file_path = self.fusion_configs["fuser_config"]["file_path"]
+         self.sensor_list = self.fusion_configs["fuser_config"]["sensor_list"]
+         FuserClass = load_class_from_file(self.fuser_file_path, self.fuser_class_path)
+         self.fuser = FuserClass()
+         self.buffer = {sensor: None for sensor in self.sensor_list}
+         assert configs.sensor_name is not None, "sensor_name must be provided in configs for the fusion node."
+         self.sensor_name = self.configs.sensor_name
+
+     def process(self, data: DetectionResultsData):
+         sensor_name = data.sensor_name
+         self.logger.info(f"Processing detection results from sensor: {sensor_name}")
+         if sensor_name not in self.sensor_list:
+             raise ValueError(f"Sensor {sensor_name} not in configured sensor list: {self.sensor_list}")
+         self.buffer[sensor_name] = data.detection_result
+         buffer = copy.copy(self.buffer)
+
+         # Wait until every configured sensor has contributed a result.
+         for _, detection_result in buffer.items():
+             if detection_result is None:
+                 return None
+         self.logger.info("All sensors received; fusing detection results.")
+         fused_result = self.fuser.fuse(buffer)
+         road_user_list_data = RoadUserListData(
+             road_user_list=fused_result,
+             capture_timestamp=data.capture_timestamp,
+             sensor_name=self.sensor_name
+         )
+         self.buffer = {sensor: None for sensor in self.sensor_list}  # reset buffer after fusion
+         return road_user_list_data
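
load_class_from_file treats the fuser as a plugin: any class reachable at class_path inside the file at file_path, exposing a fuse(buffer) method where buffer maps sensor names to detection results. A minimal illustrative plugin under those assumptions (the object_list attribute is carried over from the detection node above; the file and class names are hypothetical):

# my_fuser.py -- hypothetical plugin file referenced by fuser_config.file_path
class NaiveConcatFuser:
    """Concatenate the per-sensor object lists into one road-user list."""

    def fuse(self, buffer):
        # buffer: {sensor_name: detection_result}; each result is assumed
        # to expose an object_list, as in the detection node above.
        fused = []
        for detection_result in buffer.values():
            fused.extend(detection_result.object_list)
        return fused

The matching fuser_config entries would then be class_path: my_fuser.NaiveConcatFuser and file_path: /path/to/my_fuser.py.
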
@@ -0,0 +1,38 @@
+ from msight_core.nodes import DataProcessingNode, NodeConfig
+ from msight_core.data import RoadUserListData
+ from .. import FiniteDifferenceStateEstimator
+ from pathlib import Path
+ import yaml
+ import time
+
+ class FiniteDifferenceStateEstimatorNode(DataProcessingNode):
+     default_configs = NodeConfig(
+         publish_topic_data_type=RoadUserListData
+     )
+
+     def __init__(self, configs, state_estimation_configs_path):
+         super().__init__(configs)
+         self.config_file_path = Path(state_estimation_configs_path)
+         with open(self.config_file_path, "r") as f:
+             self.state_estimation_configs = yaml.safe_load(f)
+         self.frame_rate = self.state_estimation_configs["state_estimator_config"].get("frame_rate", 5)
+         self.frame_interval = self.state_estimation_configs["state_estimator_config"].get("frame_interval", 1)
+         self.dist_threshold = self.state_estimation_configs["state_estimator_config"].get("dist_threshold", 4)
+         self.state_estimator = FiniteDifferenceStateEstimator(
+             frame_rate=self.frame_rate,
+             frame_interval=self.frame_interval,
+             dist_threshold=self.dist_threshold
+         )
+
+     def process(self, data: RoadUserListData) -> RoadUserListData:
+         self.logger.info(f"Processing road user list data from sensor: {data.sensor_name}")
+         start = time.time()
+         # Reuse the incoming message, replacing its road user list with the estimated one.
+         data.road_user_list = self.state_estimator.estimate(data.road_user_list)
+         self.logger.info(f"State estimation completed in {time.time() - start:.2f} seconds for sensor: {data.sensor_name}")
+         return data
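
The node reads three optional keys from the state estimation YAML. A config sketch matching the code above; the values shown are the .get fallbacks used in __init__, not recommendations:

state_estimator_config:
  frame_rate: 5        # frames per second of the incoming stream
  frame_interval: 1    # how many frames back the anchor point lies
  dist_threshold: 4    # displacement (m) above which heading falls back
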
@@ -0,0 +1,31 @@
+ from msight_core.nodes import DataProcessingNode, NodeConfig
+ from msight_core.data import RoadUserListData
+ from .. import SortTracker
+ from pathlib import Path
+ import yaml
+ import time
+
+
+ class SortTrackerNode(DataProcessingNode):
+     default_configs = NodeConfig(
+         publish_topic_data_type=RoadUserListData
+     )
+
+     def __init__(self, configs, tracking_configs_path):
+         super().__init__(configs)
+         self.config_file_path = Path(tracking_configs_path)
+         with open(self.config_file_path, "r") as f:
+             self.tracking_configs = yaml.safe_load(f)  # loaded but not yet consumed by SortTracker
+         self.tracker = SortTracker()
+
+     def process(self, data: RoadUserListData) -> RoadUserListData:
+         self.logger.info(f"Processing road user list data from sensor: {data.sensor_name}")
+         start = time.time()
+         data.road_user_list = self.tracker.track(data.road_user_list)
+         self.logger.info(f"Tracking completed in {time.time() - start:.2f} seconds for sensor: {data.sensor_name}")
+         return data
@@ -0,0 +1,55 @@
+ from msight_core.nodes import SinkNode
+ from msight_core.data import RoadUserListData, DetectionResultsData
+ from msight_base.visualizer import Visualizer
+ from msight_base import Frame
+ import cv2
+
+
+ class RoadUserListViewerNode(SinkNode):
+     def __init__(self, configs, basemap_path, with_traj=True, show_heading=False):
+         super().__init__(configs)
+         self.basemap_path = basemap_path
+         self.visualizer = Visualizer(basemap_path)
+         self.with_traj = with_traj
+         self.show_heading = show_heading
+         self.step = 0
+
+     def on_message(self, data: RoadUserListData):
+         self.logger.info(f"Received road user list data from sensor: {data.sensor_name}")
+         road_user_list = data.road_user_list
+         if not self.with_traj:
+             vis_image = self.visualizer.render_roaduser_points(road_user_list)
+             cv2.imshow(self.name, vis_image)
+             cv2.waitKey(1)
+             return
+         result_frame = Frame(self.step)
+         for obj in road_user_list:
+             result_frame.add_object(obj)
+         vis_img = self.visualizer.render(result_frame, with_traj=self.with_traj, show_heading=self.show_heading)
+         cv2.imshow(self.name, vis_img)
+         cv2.waitKey(1)
+         self.step += 1
+
+ class DetectionResults2DViewerNode(SinkNode):
+     def __init__(self, configs):
+         super().__init__(configs)
+
+     def on_message(self, data: DetectionResultsData):
+         self.logger.info(f"Received detection results from sensor: {data.sensor_name}")
+         raw_image_data = data.raw_sensor_data
+         if raw_image_data is None:
+             self.logger.warning("No raw sensor data attached; enable include_sensor_data_in_result in the detection config.")
+             return
+         decoded_image = raw_image_data.to_ndarray()
+         detection_result = data.detection_result
+         for obj in detection_result.object_list:
+             # Draw the bottom-center point used for localization and the bounding box.
+             x1, y1, x2, y2 = map(int, obj.box)
+             px, py = map(int, obj.pixel_bottom_center)
+             cv2.circle(decoded_image, (px, py), 5, (0, 0, 255), -1)
+             cv2.rectangle(decoded_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
+         cv2.imshow(self.name, decoded_image)
+         cv2.waitKey(1)
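
DetectionResults2DViewerNode only has something to draw when the upstream detection node attaches the original frame, i.e. when the detection YAML (per the key read in YoloOneStageDetectionNode above) sets:

det_config:
  include_sensor_data_in_result: true
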
@@ -0,0 +1,98 @@
+ from msight_core.nodes import SinkNode, NodeConfig
+ from msight_core.data import ImageData
+ import cv2
+ import numpy as np
+ import time
+ from msight_core.utils import get_redis_client
+ from pathlib import Path
+ import yaml
+ import threading
+
+ class WarperMatrixUpdaterNode(SinkNode):
+     def __init__(self, configs, warper_configs_path):
+         '''
+         Constructor for the WarperMatrixUpdaterNode class.
+         :param configs: node configs passed to SinkNode.
+         :param warper_configs_path: path to a YAML file whose warper_config section provides
+             update_interval (steps between updates), time_threshold (seconds before an update
+             is forced), redis_prefix (Redis key prefix for storing the warp matrix), and
+             std_imgs (mapping from sensor name to its standard reference image path).
+         '''
+         super().__init__(configs)
+         self.warper_config_path = Path(warper_configs_path)
+         with open(self.warper_config_path, "r") as f:
+             self.warper_config = yaml.safe_load(f)
+         self.update_interval = self.warper_config["warper_config"]["update_interval"]
+         self.time_threshold = self.warper_config["warper_config"]["time_threshold"]
+         self.redis_prefix = self.warper_config["warper_config"]["redis_prefix"]
+         standard_images_paths = self.warper_config["warper_config"]["std_imgs"]
+         self.standard_images = {sensor_name: cv2.imread(img_path) for (sensor_name, img_path) in standard_images_paths.items()}
+         self.steps = {sensor_name: 0 for sensor_name in standard_images_paths}
+         self.last_update_times = {sensor_name: time.time() for sensor_name in standard_images_paths}
+
+     def get_warp_matrix_between_two_image(self, im1, im2):
+         # Convert images to grayscale
+         im1_gray = cv2.cvtColor(im1, cv2.COLOR_BGR2GRAY)
+         im2_gray = cv2.cvtColor(im2, cv2.COLOR_BGR2GRAY)
+
+         # Define the motion model
+         warp_mode = cv2.MOTION_HOMOGRAPHY
+
+         # Initialize the warp matrix to identity (3x3 for homography, 2x3 otherwise)
+         if warp_mode == cv2.MOTION_HOMOGRAPHY:
+             warp_matrix = np.eye(3, 3, dtype=np.float32)
+         else:
+             warp_matrix = np.eye(2, 3, dtype=np.float32)
+
+         # Number of iterations and the threshold of the increment in the
+         # correlation coefficient between two iterations
+         number_of_iterations = 500
+         termination_eps = 1e-10
+         criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
+                     number_of_iterations, termination_eps)
+
+         # Run the ECC algorithm; the refined transform is returned in warp_matrix.
+         (cc, warp_matrix) = cv2.findTransformECC(
+             im1_gray, im2_gray, warp_matrix, warp_mode, criteria)
+         return warp_matrix
+
+     def update_warp_matrix(self, image, sensor_name):
+         standard_img = self.standard_images[sensor_name]
+         warp_matrix = self.get_warp_matrix_between_two_image(
+             standard_img, image)
+         return warp_matrix
+
+     def update_warp_matrix_in_redis(self, warp_matrix, sensor_name):
+         redis_client = get_redis_client()
+         # Serialize the warp matrix as a bracketed string; the detection node
+         # parses it back with ast.literal_eval.
+         warp_matrix_str = np.array2string(warp_matrix, separator=',')
+         redis_client.set(self.redis_prefix + f":{sensor_name}", warp_matrix_str)
+
+     def on_message(self, data: ImageData):
+         sensor_name = data.sensor_name
+         self.logger.debug(f"received one image for {sensor_name}")
+         if self.steps[sensor_name] % self.update_interval == 0 or time.time() - self.last_update_times[sensor_name] > self.time_threshold:
+             def _update():
+                 self.logger.info(f"Updating warp matrix for {sensor_name}")
+                 start = time.time()
+                 image = cv2.imdecode(data.encoded_image, cv2.IMREAD_COLOR)
+                 warp_matrix = self.update_warp_matrix(image, sensor_name)
+                 self.update_warp_matrix_in_redis(warp_matrix, sensor_name)
+                 self.logger.info(f"Updated warp matrix for sensor: {sensor_name} in {time.time() - start:.2f} seconds")
+             x = threading.Thread(target=_update, daemon=True)
+             x.start()
+             # Record the update time only when an update is actually triggered,
+             # so time_threshold measures seconds since the last update.
+             self.last_update_times[sensor_name] = time.time()
+         self.steps[sensor_name] = (self.steps[sensor_name] + 1) % self.update_interval
@@ -0,0 +1,121 @@
+ from typing import List
+ from msight_base import RoadUserPoint, TrajectoryManager
+ from geopy.distance import geodesic
+ import numpy as np
+
+
+ class StateEstimatorBase:
+     def estimate(self, road_user_point_list: List[RoadUserPoint]) -> List[RoadUserPoint]:
+         """
+         Estimate the state of road users based on the provided list of RoadUserPoint instances.
+         :param road_user_point_list: List of RoadUserPoint instances to estimate the state from.
+         :return: Estimated state of road users.
+         """
+         raise NotImplementedError("estimate() must be overridden by subclasses.")
+
+ class FiniteDifferenceStateEstimator(StateEstimatorBase):
+     def __init__(self, frame_rate=5, frame_interval=1, dist_threshold=2):
+         """
+         Initialize the finite difference state estimator.
+         :param frame_rate: Frame rate of the incoming stream, in frames per second.
+         :param frame_interval: How many frames back the anchor point lies; 1 uses the immediately preceding frame.
+         :param dist_threshold: Displacement (in meters) above which the estimate is treated as an outlier.
+         """
+         self.frame_rate = frame_rate
+         self.frame_interval = frame_interval
+         self.trajectory_manager = TrajectoryManager(max_frames=100)
+         self.dist_threshold = dist_threshold
+
+     def get_anchor_point(self, obj):
+         traj = obj.traj
+         if len(traj.steps) <= 1:
+             return None
+         if len(traj.steps) < self.frame_interval + 2:
+             anchor_index = 0
+         else:
+             anchor_index = len(traj.steps) - self.frame_interval - 1
+         anchor_step = traj.steps[anchor_index]
+         anchor = traj.step_to_object_map[anchor_step]
+         return anchor
+
+     def calc_xy_difference(self, obj: RoadUserPoint, anchor: RoadUserPoint, scale="latlon"):
+         """
+         Calculate the signed displacement between the object and the anchor point in the given scale.
+         :param obj: RoadUserPoint instance representing the object.
+         :param anchor: RoadUserPoint instance representing the anchor point.
+         :param scale: Scale of the coordinates, either "latlon", "utm" or "meters" ("meters" and "utm" are equivalent).
+         :return: (dx, dy) displacement of the object relative to the anchor, in meters.
+         """
+         if scale == "latlon":
+             lat1, lon1 = obj.x, obj.y
+             lat2, lon2 = anchor.x, anchor.y
+             # Convert the latitude/longitude deltas to meters using geodesic
+             # distances, keeping the sign of each component.
+             dx = geodesic((lat1, lon1), (lat2, lon1)).meters
+             if lat1 < lat2:
+                 dx = -dx
+             dy = geodesic((lat1, lon1), (lat1, lon2)).meters
+             if lon1 < lon2:
+                 dy = -dy
+         elif scale in ["utm", "meters"]:
+             # Planar coordinates: just take the difference.
+             dx = obj.x - anchor.x
+             dy = obj.y - anchor.y
+         else:
+             raise ValueError("Invalid scale. Use 'latlon', 'utm' or 'meters'.")
+         return dx, dy
+
+     def calc_heading(self, dx: float, dy: float, spatial_distance, fallback) -> float:
+         """
+         Calculate the heading of the object from the displacement components.
+         :param dx: Displacement along the x axis, in meters.
+         :param dy: Displacement along the y axis, in meters.
+         :param spatial_distance: Magnitude of the displacement, in meters.
+         :param fallback: Heading to return when the displacement is rejected as an outlier.
+         :return: Heading of the object in degrees.
+         """
+         # clock-wise, in degrees
+         # north: 0, east: 90, north-east: 45
+         # A jump larger than dist_threshold is treated as an outlier; keep the previous heading.
+         if spatial_distance > self.dist_threshold:
+             return fallback
+         heading = np.arctan2(dy, dx) * 180 / np.pi
+         return heading
+
+     def calc_speed(self, obj: RoadUserPoint, anchor: RoadUserPoint, spatial_distance) -> float:
+         """
+         Calculate the speed of the object from its displacement and the elapsed frames.
+         :param obj: RoadUserPoint instance representing the object.
+         :param anchor: RoadUserPoint instance representing the anchor point.
+         :param spatial_distance: Magnitude of the displacement, in meters.
+         :return: Speed of the object in meters per second.
+         """
+         step_now = obj.frame_step
+         step_anchor = anchor.frame_step
+         time_difference = (step_now - step_anchor) / self.frame_rate  # in seconds
+         speed = (spatial_distance / time_difference) if time_difference > 0 else 0.0
+         return speed
+
+     def estimate(self, road_user_point_list: List[RoadUserPoint], scale="latlon") -> List[RoadUserPoint]:
+         """
+         Estimate the state of road users based on the provided list of RoadUserPoint instances.
+         :param road_user_point_list: List of RoadUserPoint instances to estimate the state from.
+         :param scale: Scale of the coordinates, either "latlon", "utm" or "meters" ("meters" and "utm" are equivalent).
+         :return: Estimated state of road users.
+         """
+         self.trajectory_manager.add_list_as_new_frame(road_user_point_list)
+         for obj in road_user_point_list:
+             anchor = self.get_anchor_point(obj)
+             if anchor is None:
+                 continue
+             dx, dy = self.calc_xy_difference(obj, anchor, scale)
+             spatial_distance = (dx**2 + dy**2)**0.5
+             obj.heading = self.calc_heading(dx, dy, spatial_distance, anchor.heading)
+             obj.speed = self.calc_speed(obj, anchor, spatial_distance)
+
+         return road_user_point_list
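
As a sanity check of the finite-difference math above: a displacement of 1 m along x and 1 m along y between consecutive frames at 5 fps should yield a 45-degree heading and roughly 7 m/s (a minimal sketch that reproduces the arithmetic, not the geodesic branch):

import numpy as np

dx, dy = 1.0, 1.0                            # meters along x and y relative to the anchor
spatial_distance = (dx**2 + dy**2) ** 0.5    # ~1.414 m
heading = np.arctan2(dy, dx) * 180 / np.pi   # 45.0 degrees
time_difference = 1 / 5                      # one frame at 5 fps
speed = spatial_distance / time_difference   # ~7.07 m/s
print(heading, speed)
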