simple-autonomous-car 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simple_autonomous_car/__init__.py +96 -0
- simple_autonomous_car/alerts/__init__.py +5 -0
- simple_autonomous_car/alerts/track_bounds_alert.py +276 -0
- simple_autonomous_car/car/__init__.py +5 -0
- simple_autonomous_car/car/car.py +234 -0
- simple_autonomous_car/constants.py +112 -0
- simple_autonomous_car/control/__init__.py +7 -0
- simple_autonomous_car/control/base_controller.py +152 -0
- simple_autonomous_car/control/controller_viz.py +282 -0
- simple_autonomous_car/control/pid_controller.py +153 -0
- simple_autonomous_car/control/pure_pursuit_controller.py +578 -0
- simple_autonomous_car/costmap/__init__.py +12 -0
- simple_autonomous_car/costmap/base_costmap.py +187 -0
- simple_autonomous_car/costmap/grid_costmap.py +507 -0
- simple_autonomous_car/costmap/inflation.py +126 -0
- simple_autonomous_car/detection/__init__.py +5 -0
- simple_autonomous_car/detection/error_detector.py +165 -0
- simple_autonomous_car/filters/__init__.py +7 -0
- simple_autonomous_car/filters/base_filter.py +119 -0
- simple_autonomous_car/filters/kalman_filter.py +131 -0
- simple_autonomous_car/filters/particle_filter.py +162 -0
- simple_autonomous_car/footprint/__init__.py +7 -0
- simple_autonomous_car/footprint/base_footprint.py +128 -0
- simple_autonomous_car/footprint/circular_footprint.py +73 -0
- simple_autonomous_car/footprint/rectangular_footprint.py +123 -0
- simple_autonomous_car/frames/__init__.py +21 -0
- simple_autonomous_car/frames/frenet.py +267 -0
- simple_autonomous_car/maps/__init__.py +9 -0
- simple_autonomous_car/maps/frenet_map.py +97 -0
- simple_autonomous_car/maps/grid_ground_truth_map.py +83 -0
- simple_autonomous_car/maps/grid_map.py +361 -0
- simple_autonomous_car/maps/ground_truth_map.py +64 -0
- simple_autonomous_car/maps/perceived_map.py +169 -0
- simple_autonomous_car/perception/__init__.py +5 -0
- simple_autonomous_car/perception/perception.py +107 -0
- simple_autonomous_car/planning/__init__.py +7 -0
- simple_autonomous_car/planning/base_planner.py +184 -0
- simple_autonomous_car/planning/goal_planner.py +261 -0
- simple_autonomous_car/planning/track_planner.py +199 -0
- simple_autonomous_car/sensors/__init__.py +6 -0
- simple_autonomous_car/sensors/base_sensor.py +105 -0
- simple_autonomous_car/sensors/lidar_sensor.py +145 -0
- simple_autonomous_car/track/__init__.py +5 -0
- simple_autonomous_car/track/track.py +463 -0
- simple_autonomous_car/visualization/__init__.py +25 -0
- simple_autonomous_car/visualization/alert_viz.py +316 -0
- simple_autonomous_car/visualization/utils.py +169 -0
- simple_autonomous_car-0.1.2.dist-info/METADATA +324 -0
- simple_autonomous_car-0.1.2.dist-info/RECORD +50 -0
- simple_autonomous_car-0.1.2.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""Inflation utilities for costmaps."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from scipy.ndimage import maximum_filter # type: ignore[import-untyped]
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def compute_inflation_kernel(radius: float, resolution: float) -> np.ndarray:
    """
    Compute inflation kernel for obstacle inflation.

    The kernel is a square array with an odd side length, so the obstacle
    cell sits exactly at its centre. Each entry holds a cost that decays
    linearly from 1.0 at the centre to 0.0 at ``radius``; entries farther
    away than ``radius`` are 0.0.

    Parameters
    ----------
    radius : float
        Inflation radius in meters. Must be non-negative.
    resolution : float
        Costmap resolution in meters per cell. Must be positive.

    Returns
    -------
    np.ndarray
        2D kernel array for inflation, of shape ``(2k + 1, 2k + 1)`` where
        ``k = int(radius / resolution)``.

    Raises
    ------
    ValueError
        If ``resolution`` is not positive or ``radius`` is negative.

    Examples
    --------
    >>> kernel = compute_inflation_kernel(radius=1.0, resolution=0.5)
    >>> print(kernel.shape)
    (5, 5)
    """
    if resolution <= 0:
        raise ValueError(f"resolution must be positive, got {resolution}")
    if radius < 0:
        raise ValueError(f"radius must be non-negative, got {radius}")

    # A zero radius inflates nothing: only the obstacle cell itself carries
    # cost. Handled separately to avoid the 0/0 in the decay formula below.
    if radius == 0:
        return np.ones((1, 1))

    # Number of whole cells that fit inside the radius on each side of the
    # centre. Building the side length as 2k + 1 keeps the kernel symmetric;
    # int(2 * radius / resolution) + 1 can come out even and off-centre.
    half_width = int(radius / resolution)

    # Metric offset of every row/column from the centre cell.
    offsets = (np.arange(2 * half_width + 1) - half_width) * resolution
    distance = np.sqrt(offsets[:, np.newaxis] ** 2 + offsets[np.newaxis, :] ** 2)

    # Linear decay from 1.0 at center to 0.0 at radius, zero outside it.
    return np.where(distance <= radius, 1.0 - distance / radius, 0.0)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def inflate_obstacles(
    costmap: np.ndarray,
    inflation_radius: float,
    resolution: float,
    method: str = "linear",
) -> np.ndarray:
    """
    Inflate obstacles in a costmap.

    Cells with a cost of at least 0.5 are treated as obstacles by the
    "linear" method. Inflation never lowers the cost already stored in a
    cell, and cells outside the inflation radius are returned unchanged.

    Parameters
    ----------
    costmap : np.ndarray
        2D costmap array (0.0 = free, 1.0 = occupied).
    inflation_radius : float
        Inflation radius in meters. Must be non-negative.
    resolution : float
        Costmap resolution in meters per cell. Must be positive.
    method : str, default="linear"
        Inflation method: "linear" (linear decay) or "binary" (binary expansion).

    Returns
    -------
    np.ndarray
        Inflated costmap. The "linear" method keeps a floating-point input
        dtype and promotes any other dtype to float64 so the fractional
        decay values are not truncated; the "binary" method keeps the input
        dtype.

    Raises
    ------
    ValueError
        If ``method`` is not "linear" or "binary", ``resolution`` is not
        positive, or ``inflation_radius`` is negative.

    Examples
    --------
    >>> costmap = np.zeros((100, 100))
    >>> costmap[50, 50] = 1.0  # Obstacle
    >>> inflated = inflate_obstacles(costmap, inflation_radius=2.0, resolution=0.5)
    """
    if method not in ("linear", "binary"):
        raise ValueError(f"Unknown inflation method: {method!r} (expected 'linear' or 'binary')")
    if resolution <= 0:
        raise ValueError(f"resolution must be positive, got {resolution}")
    if inflation_radius < 0:
        raise ValueError(f"inflation_radius must be non-negative, got {inflation_radius}")

    costmap = np.asarray(costmap)

    if method == "binary":
        # Odd-sized square footprint so the expansion is symmetric around
        # each cell (an even size would shift the window by one cell).
        half_width = int(inflation_radius / resolution)
        side = 2 * half_width + 1
        footprint = np.ones((side, side), dtype=bool)
        inflated = maximum_filter(costmap, footprint=footprint)
        return np.asarray(inflated, dtype=costmap.dtype)

    # Linear inflation using a Euclidean distance transform.
    from scipy.ndimage import distance_transform_edt

    # Work in floating point: the decay values are fractional.
    out_dtype = costmap.dtype if np.issubdtype(costmap.dtype, np.floating) else np.float64
    result = costmap.astype(out_dtype)  # astype copies; the input is not modified

    # Create binary obstacle map
    obstacles = costmap >= 0.5

    # Without any obstacle cell the distance transform has nothing to measure
    # against, so its output would not be a meaningful distance; skip it.
    if inflation_radius > 0 and obstacles.any():
        # Distance (in meters) from every cell to the nearest obstacle cell.
        distances = distance_transform_edt(~obstacles) * resolution

        # Linear decay inside the radius; keep whichever cost is higher so
        # inflation never reduces a pre-existing cost.
        mask = distances <= inflation_radius
        decay = 1.0 - distances[mask] / inflation_radius
        result[mask] = np.maximum(result[mask], decay)

    # Preserve original obstacles (set to max cost)
    result[obstacles] = 1.0

    return result
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def compute_distance_transform(costmap: np.ndarray, resolution: float) -> np.ndarray:
    """
    Compute distance transform of costmap.

    Only cells whose cost is exactly 1.0 are treated as obstacles; every
    other value counts as free space.

    Parameters
    ----------
    costmap : np.ndarray
        2D costmap array.
    resolution : float
        Costmap resolution in meters per cell.

    Returns
    -------
    np.ndarray
        Distance transform (distance to nearest obstacle in meters), as
        float64. Obstacle cells have distance 0.0. If the costmap contains
        no obstacle cell, every distance is ``inf``.
    """
    from scipy.ndimage import distance_transform_edt

    # Equivalent to testing (1.0 - costmap) for non-zero: obstacles become
    # False (background), free space becomes True (foreground).
    free_space = np.asarray(costmap) != 1.0

    # With no background cell the transform has no obstacle to measure to,
    # so its output would not be a meaningful distance. Report "infinitely
    # far" instead.
    if free_space.all():
        return np.full(free_space.shape, np.inf, dtype=np.float64)

    # Distance in cells, scaled to meters.
    distances = distance_transform_edt(free_space) * resolution
    return np.asarray(distances, dtype=np.float64)
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""Localization error detection between perceived and ground truth maps."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from simple_autonomous_car.car.car import CarState
|
|
6
|
+
from simple_autonomous_car.maps.ground_truth_map import GroundTruthMap
|
|
7
|
+
from simple_autonomous_car.maps.perceived_map import PerceivedMap
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class LocalizationErrorDetector:
    """Detects errors between perceived and ground truth localization."""

    def __init__(
        self,
        ground_truth_map: GroundTruthMap,
        perceived_map: PerceivedMap,
        error_threshold: float = 1.0,
    ):
        """
        Initialize error detector.

        Args:
            ground_truth_map: Ground truth map
            perceived_map: Perceived map
            error_threshold: Distance threshold for error detection (meters)
        """
        self.ground_truth_map = ground_truth_map
        self.perceived_map = perceived_map
        self.error_threshold = error_threshold

    def compute_errors(
        self, car_state: CarState, horizon: float, fov: float = 2 * np.pi
    ) -> dict[str, np.ndarray]:
        """
        Compute errors between perceived and ground truth segments.

        Calls ``update_perceived_state`` on the perceived map before
        querying it, so this method mutates the perceived map.

        Args:
            car_state: Current car state
            horizon: Maximum distance to consider
            fov: Field of view angle in radians

        Returns:
            Dictionary with error metrics:
            - centerline_errors: Distance errors for centerline points
            - inner_bound_errors: Distance errors for inner boundary points
            - outer_bound_errors: Distance errors for outer boundary points
            - max_error: Maximum error across all segments (float)
            - mean_error: Mean error across all segments (float)

            When no ground-truth centerline point is visible, the three
            arrays are empty and both scalars are 0.0.
        """
        # Update perceived state
        self.perceived_map.update_perceived_state(car_state)

        # Get ground truth segments visible from the car
        gt_centerline, gt_inner, gt_outer = self.ground_truth_map.get_visible_segments(
            car_state.position(), car_state.heading, horizon, fov
        )

        if len(gt_centerline) == 0:
            return {
                "centerline_errors": np.array([]),
                "inner_bound_errors": np.array([]),
                "outer_bound_errors": np.array([]),
                "max_error": 0.0,  # type: ignore[dict-item]
                "mean_error": 0.0,  # type: ignore[dict-item]
            }

        # Transform GT to car frame
        gt_centerline_car = np.array(
            [car_state.transform_to_car_frame(point) for point in gt_centerline]
        )
        gt_inner_car = np.array([car_state.transform_to_car_frame(point) for point in gt_inner])
        gt_outer_car = np.array([car_state.transform_to_car_frame(point) for point in gt_outer])

        # Get perceived segments in car frame
        perceived_centerline_car, perceived_inner_car, perceived_outer_car = (
            self.perceived_map.get_perceived_segments(horizon, fov)
        )

        # Compute errors (using nearest neighbor matching)
        centerline_errors = self._compute_point_errors(gt_centerline_car, perceived_centerline_car)
        inner_bound_errors = self._compute_point_errors(gt_inner_car, perceived_inner_car)
        outer_bound_errors = self._compute_point_errors(gt_outer_car, perceived_outer_car)

        # Aggregate errors; cast to plain float so both return paths of this
        # method yield the same scalar type.
        all_errors = np.concatenate([centerline_errors, inner_bound_errors, outer_bound_errors])
        if all_errors.size > 0:
            max_error = float(np.max(all_errors))
            mean_error = float(np.mean(all_errors))
        else:
            max_error = 0.0
            mean_error = 0.0

        return {
            "centerline_errors": centerline_errors,
            "inner_bound_errors": inner_bound_errors,
            "outer_bound_errors": outer_bound_errors,
            "max_error": max_error,  # type: ignore[dict-item]
            "mean_error": mean_error,  # type: ignore[dict-item]
        }

    def _compute_point_errors(
        self, ground_truth_points: np.ndarray, perceived_points: np.ndarray
    ) -> np.ndarray:
        """
        Compute distance errors between ground truth and perceived points.

        Each ground truth point is matched to its nearest perceived point
        and the Euclidean distance between the two is the error.

        Args:
            ground_truth_points: Ground truth points (N, 2)
            perceived_points: Perceived points (M, 2)

        Returns:
            Array of errors for each ground truth point
        """
        # NOTE(review): with no perceived points every error is reported as
        # zero, so a complete perception dropout reads as "no error" —
        # confirm this is the intended behaviour.
        if len(perceived_points) == 0:
            return np.zeros(len(ground_truth_points))

        # Nothing to score; also keeps the broadcast below from seeing a
        # 1-D empty array.
        if len(ground_truth_points) == 0:
            return np.array([])

        gt = np.asarray(ground_truth_points)
        perceived = np.asarray(perceived_points)

        # Pairwise distances via broadcasting: entry [i, j] is the distance
        # from ground truth point i to perceived point j.
        pairwise = np.linalg.norm(
            perceived[np.newaxis, :, :] - gt[:, np.newaxis, :], axis=2
        )

        # Nearest perceived point for each ground truth point.
        return pairwise.min(axis=1)

    def detect_errors(
        self, car_state: CarState, horizon: float, fov: float = 2 * np.pi
    ) -> dict[str, bool]:
        """
        Detect if errors exceed threshold.

        Args:
            car_state: Current car state
            horizon: Maximum distance to consider
            fov: Field of view angle in radians

        Returns:
            Dictionary with detection flags:
            - has_error: True if any error exceeds threshold
            - centerline_has_error: True if centerline errors exceed threshold
            - inner_bound_has_error: True if inner bound errors exceed threshold
            - outer_bound_has_error: True if outer bound errors exceed threshold
        """
        errors = self.compute_errors(car_state, horizon, fov)
        threshold = self.error_threshold

        # np.any of an empty array is False, so empty error arrays need no
        # special handling.
        centerline_has_error = bool(np.any(errors["centerline_errors"] > threshold))
        inner_bound_has_error = bool(np.any(errors["inner_bound_errors"] > threshold))
        outer_bound_has_error = bool(np.any(errors["outer_bound_errors"] > threshold))

        has_error = centerline_has_error or inner_bound_has_error or outer_bound_has_error

        return {
            "has_error": has_error,
            "centerline_has_error": centerline_has_error,
            "inner_bound_has_error": inner_bound_has_error,
            "outer_bound_has_error": outer_bound_has_error,
        }
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"""Filtering components for pose and object estimation."""
|
|
2
|
+
|
|
3
|
+
from simple_autonomous_car.filters.base_filter import BaseFilter
|
|
4
|
+
from simple_autonomous_car.filters.kalman_filter import KalmanFilter
|
|
5
|
+
from simple_autonomous_car.filters.particle_filter import ParticleFilter
|
|
6
|
+
|
|
7
|
+
__all__ = ["BaseFilter", "KalmanFilter", "ParticleFilter"]
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""Base filter class for pose and object estimation."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseFilter(ABC):
|
|
9
|
+
"""
|
|
10
|
+
Base class for all filters.
|
|
11
|
+
|
|
12
|
+
Filters estimate state (position, velocity, etc.) from noisy measurements.
|
|
13
|
+
They can be used for:
|
|
14
|
+
- Ego pose estimation (car's own position/heading)
|
|
15
|
+
- Object tracking (obstacle positions/velocities)
|
|
16
|
+
|
|
17
|
+
Attributes
|
|
18
|
+
----------
|
|
19
|
+
name : str
|
|
20
|
+
Filter name/identifier.
|
|
21
|
+
enabled : bool
|
|
22
|
+
Whether the filter is enabled.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
def __init__(self, name: str = "filter", enabled: bool = True):
|
|
26
|
+
"""
|
|
27
|
+
Initialize base filter.
|
|
28
|
+
|
|
29
|
+
Parameters
|
|
30
|
+
----------
|
|
31
|
+
name : str, default="filter"
|
|
32
|
+
Filter name/identifier.
|
|
33
|
+
enabled : bool, default=True
|
|
34
|
+
Whether the filter is enabled.
|
|
35
|
+
"""
|
|
36
|
+
self.name = name
|
|
37
|
+
self.enabled = enabled
|
|
38
|
+
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def predict(self, dt: float, control: dict[str, float] | None = None) -> None:
|
|
41
|
+
"""
|
|
42
|
+
Predict state forward in time.
|
|
43
|
+
|
|
44
|
+
Parameters
|
|
45
|
+
----------
|
|
46
|
+
dt : float
|
|
47
|
+
Time step in seconds.
|
|
48
|
+
control : dict, optional
|
|
49
|
+
Control inputs (e.g., {"acceleration": 0.5, "steering_rate": 0.1}).
|
|
50
|
+
"""
|
|
51
|
+
pass
|
|
52
|
+
|
|
53
|
+
@abstractmethod
|
|
54
|
+
def update(
|
|
55
|
+
self, measurement: np.ndarray, measurement_covariance: np.ndarray | None = None
|
|
56
|
+
) -> None:
|
|
57
|
+
"""
|
|
58
|
+
Update state estimate with measurement.
|
|
59
|
+
|
|
60
|
+
Parameters
|
|
61
|
+
----------
|
|
62
|
+
measurement : np.ndarray
|
|
63
|
+
Measurement vector (e.g., [x, y] for position).
|
|
64
|
+
measurement_covariance : np.ndarray, optional
|
|
65
|
+
Measurement noise covariance matrix.
|
|
66
|
+
"""
|
|
67
|
+
pass
|
|
68
|
+
|
|
69
|
+
@abstractmethod
|
|
70
|
+
def get_state(self) -> np.ndarray:
|
|
71
|
+
"""
|
|
72
|
+
Get current state estimate.
|
|
73
|
+
|
|
74
|
+
Returns
|
|
75
|
+
-------
|
|
76
|
+
np.ndarray
|
|
77
|
+
State vector (e.g., [x, y, vx, vy] for position and velocity).
|
|
78
|
+
"""
|
|
79
|
+
pass
|
|
80
|
+
|
|
81
|
+
@abstractmethod
|
|
82
|
+
def get_covariance(self) -> np.ndarray:
|
|
83
|
+
"""
|
|
84
|
+
Get current state covariance (uncertainty).
|
|
85
|
+
|
|
86
|
+
Returns
|
|
87
|
+
-------
|
|
88
|
+
np.ndarray
|
|
89
|
+
Covariance matrix.
|
|
90
|
+
"""
|
|
91
|
+
pass
|
|
92
|
+
|
|
93
|
+
def is_enabled(self) -> bool:
|
|
94
|
+
"""Check if filter is enabled."""
|
|
95
|
+
return self.enabled
|
|
96
|
+
|
|
97
|
+
def enable(self) -> None:
|
|
98
|
+
"""Enable the filter."""
|
|
99
|
+
self.enabled = True
|
|
100
|
+
|
|
101
|
+
def disable(self) -> None:
|
|
102
|
+
"""Disable the filter."""
|
|
103
|
+
self.enabled = False
|
|
104
|
+
|
|
105
|
+
def reset(
|
|
106
|
+
self, initial_state: np.ndarray, initial_covariance: np.ndarray | None = None
|
|
107
|
+
) -> None:
|
|
108
|
+
"""
|
|
109
|
+
Reset filter to initial state.
|
|
110
|
+
|
|
111
|
+
Parameters
|
|
112
|
+
----------
|
|
113
|
+
initial_state : np.ndarray
|
|
114
|
+
Initial state vector.
|
|
115
|
+
initial_covariance : np.ndarray, optional
|
|
116
|
+
Initial covariance matrix.
|
|
117
|
+
"""
|
|
118
|
+
# Default implementation - subclasses should override
|
|
119
|
+
pass
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
"""Kalman filter for state estimation."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from simple_autonomous_car.filters.base_filter import BaseFilter
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class KalmanFilter(BaseFilter):
    """
    Kalman filter for state estimation.

    Estimates state (position, velocity, etc.) from noisy measurements
    using the Kalman filter algorithm.

    Parameters
    ----------
    initial_state : np.ndarray
        Initial state vector (e.g., [x, y, vx, vy]).
    initial_covariance : np.ndarray
        Initial covariance matrix.
    process_noise : np.ndarray
        Process noise covariance matrix (Q).
    measurement_noise : np.ndarray
        Measurement noise covariance matrix (R).
    name : str, default="kalman_filter"
        Filter name.

    Attributes
    ----------
    F : np.ndarray
        State transition matrix, identity by default.
    H : np.ndarray
        Measurement matrix; by default an identity-like
        ``(measurement_dim, state_dim)`` matrix.
    """

    def __init__(
        self,
        initial_state: np.ndarray,
        initial_covariance: np.ndarray,
        process_noise: np.ndarray,
        measurement_noise: np.ndarray,
        name: str = "kalman_filter",
    ):
        super().__init__(name=name)
        self.state = np.asarray(initial_state, dtype=np.float64).copy()
        self.covariance = np.asarray(initial_covariance, dtype=np.float64).copy()
        self.process_noise = np.asarray(process_noise, dtype=np.float64)
        self.measurement_noise = np.asarray(measurement_noise, dtype=np.float64)

        # Dimensions are read from the converted arrays so plain lists are
        # accepted for every argument, not just the first two.
        self.state_dim = len(self.state)
        self.measurement_dim = self.measurement_noise.shape[0]

        # State transition matrix (identity by default)
        self.F = np.eye(self.state_dim)
        # Measurement matrix (identity by default, can be overridden)
        self.H = np.eye(self.measurement_dim, self.state_dim)

    def predict(self, dt: float, control: dict[str, float] | None = None) -> None:
        """
        Predict state forward in time.

        Does nothing while the filter is disabled.

        Parameters
        ----------
        dt : float
            Time step in seconds.
        control : dict, optional
            Control inputs (not used in basic Kalman filter).
        """
        if not self.enabled:
            return

        # For states with at least four components a constant velocity model
        # is assumed: x' = x + vx*dt, y' = y + vy*dt.
        # NOTE(review): this rebuilds self.F on every call, so a custom F
        # assigned by the caller is discarded when state_dim >= 4.
        if self.state_dim >= 4:  # Has velocity components
            f_matrix = np.eye(self.state_dim)
            f_matrix[0, 2] = dt  # x += vx*dt
            f_matrix[1, 3] = dt  # y += vy*dt
            self.F = f_matrix

        # Predict state: x' = F * x
        self.state = self.F @ self.state

        # Predict covariance: P' = F * P * F^T + Q
        self.covariance = self.F @ self.covariance @ self.F.T + self.process_noise

    def update(
        self, measurement: np.ndarray, measurement_covariance: np.ndarray | None = None
    ) -> None:
        """
        Update state estimate with measurement.

        Does nothing while the filter is disabled.

        Parameters
        ----------
        measurement : np.ndarray
            Measurement vector.
        measurement_covariance : np.ndarray, optional
            Measurement noise covariance (uses default if not provided).
        """
        if not self.enabled:
            return

        measurement = np.asarray(measurement, dtype=np.float64)
        r_matrix = (
            measurement_covariance if measurement_covariance is not None else self.measurement_noise
        )

        # Innovation (measurement residual)
        y = measurement - self.H @ self.state

        # Innovation covariance: S = H * P * H^T + R
        p_ht = self.covariance @ self.H.T
        s_matrix = self.H @ p_ht + r_matrix

        # Kalman gain K = P * H^T * S^-1, obtained by solving S^T * K^T =
        # (P * H^T)^T instead of forming the explicit inverse of S.
        k_gain = np.linalg.solve(s_matrix.T, p_ht.T).T

        # Update state: x = x + K * y
        self.state = self.state + k_gain @ y

        # Update covariance: P = (I - K*H) * P
        identity_matrix = np.eye(self.state_dim)
        self.covariance = (identity_matrix - k_gain @ self.H) @ self.covariance

    def get_state(self) -> np.ndarray:
        """Get a copy of the current state estimate."""
        return np.array(self.state, dtype=np.float64)

    def get_covariance(self) -> np.ndarray:
        """Get a copy of the current state covariance."""
        return np.array(self.covariance, dtype=np.float64)

    def reset(
        self, initial_state: np.ndarray, initial_covariance: np.ndarray | None = None
    ) -> None:
        """
        Reset filter to initial state.

        Parameters
        ----------
        initial_state : np.ndarray
            New state vector.
        initial_covariance : np.ndarray, optional
            New covariance matrix; the current covariance is kept when omitted.
        """
        self.state = np.asarray(initial_state, dtype=np.float64).copy()
        if initial_covariance is not None:
            self.covariance = np.asarray(initial_covariance, dtype=np.float64).copy()
|