matrice-analytics 0.1.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. matrice_analytics/__init__.py +28 -0
  2. matrice_analytics/boundary_drawing_internal/README.md +305 -0
  3. matrice_analytics/boundary_drawing_internal/__init__.py +45 -0
  4. matrice_analytics/boundary_drawing_internal/boundary_drawing_internal.py +1207 -0
  5. matrice_analytics/boundary_drawing_internal/boundary_drawing_tool.py +429 -0
  6. matrice_analytics/boundary_drawing_internal/boundary_tool_template.html +1036 -0
  7. matrice_analytics/boundary_drawing_internal/data/.gitignore +12 -0
  8. matrice_analytics/boundary_drawing_internal/example_usage.py +206 -0
  9. matrice_analytics/boundary_drawing_internal/usage/README.md +110 -0
  10. matrice_analytics/boundary_drawing_internal/usage/boundary_drawer_launcher.py +102 -0
  11. matrice_analytics/boundary_drawing_internal/usage/simple_boundary_launcher.py +107 -0
  12. matrice_analytics/post_processing/README.md +455 -0
  13. matrice_analytics/post_processing/__init__.py +732 -0
  14. matrice_analytics/post_processing/advanced_tracker/README.md +650 -0
  15. matrice_analytics/post_processing/advanced_tracker/__init__.py +17 -0
  16. matrice_analytics/post_processing/advanced_tracker/base.py +99 -0
  17. matrice_analytics/post_processing/advanced_tracker/config.py +77 -0
  18. matrice_analytics/post_processing/advanced_tracker/kalman_filter.py +370 -0
  19. matrice_analytics/post_processing/advanced_tracker/matching.py +195 -0
  20. matrice_analytics/post_processing/advanced_tracker/strack.py +230 -0
  21. matrice_analytics/post_processing/advanced_tracker/tracker.py +367 -0
  22. matrice_analytics/post_processing/config.py +146 -0
  23. matrice_analytics/post_processing/core/__init__.py +63 -0
  24. matrice_analytics/post_processing/core/base.py +704 -0
  25. matrice_analytics/post_processing/core/config.py +3291 -0
  26. matrice_analytics/post_processing/core/config_utils.py +925 -0
  27. matrice_analytics/post_processing/face_reg/__init__.py +43 -0
  28. matrice_analytics/post_processing/face_reg/compare_similarity.py +556 -0
  29. matrice_analytics/post_processing/face_reg/embedding_manager.py +950 -0
  30. matrice_analytics/post_processing/face_reg/face_recognition.py +2234 -0
  31. matrice_analytics/post_processing/face_reg/face_recognition_client.py +606 -0
  32. matrice_analytics/post_processing/face_reg/people_activity_logging.py +321 -0
  33. matrice_analytics/post_processing/ocr/__init__.py +0 -0
  34. matrice_analytics/post_processing/ocr/easyocr_extractor.py +250 -0
  35. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
  36. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
  37. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
  38. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
  39. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
  40. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
  41. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
  42. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
  43. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
  44. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
  45. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
  46. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
  47. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
  48. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
  49. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
  50. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
  51. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
  52. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
  53. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
  54. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
  55. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
  56. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
  57. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
  58. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
  59. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
  60. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
  61. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
  62. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
  63. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
  64. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
  65. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
  66. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
  67. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
  68. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
  69. matrice_analytics/post_processing/ocr/postprocessing.py +270 -0
  70. matrice_analytics/post_processing/ocr/preprocessing.py +52 -0
  71. matrice_analytics/post_processing/post_processor.py +1175 -0
  72. matrice_analytics/post_processing/test_cases/__init__.py +1 -0
  73. matrice_analytics/post_processing/test_cases/run_tests.py +143 -0
  74. matrice_analytics/post_processing/test_cases/test_advanced_customer_service.py +841 -0
  75. matrice_analytics/post_processing/test_cases/test_basic_counting_tracking.py +523 -0
  76. matrice_analytics/post_processing/test_cases/test_comprehensive.py +531 -0
  77. matrice_analytics/post_processing/test_cases/test_config.py +852 -0
  78. matrice_analytics/post_processing/test_cases/test_customer_service.py +585 -0
  79. matrice_analytics/post_processing/test_cases/test_data_generators.py +583 -0
  80. matrice_analytics/post_processing/test_cases/test_people_counting.py +510 -0
  81. matrice_analytics/post_processing/test_cases/test_processor.py +524 -0
  82. matrice_analytics/post_processing/test_cases/test_usecases.py +165 -0
  83. matrice_analytics/post_processing/test_cases/test_utilities.py +356 -0
  84. matrice_analytics/post_processing/test_cases/test_utils.py +743 -0
  85. matrice_analytics/post_processing/usecases/Histopathological_Cancer_Detection_img.py +604 -0
  86. matrice_analytics/post_processing/usecases/__init__.py +267 -0
  87. matrice_analytics/post_processing/usecases/abandoned_object_detection.py +797 -0
  88. matrice_analytics/post_processing/usecases/advanced_customer_service.py +1601 -0
  89. matrice_analytics/post_processing/usecases/age_detection.py +842 -0
  90. matrice_analytics/post_processing/usecases/age_gender_detection.py +1085 -0
  91. matrice_analytics/post_processing/usecases/anti_spoofing_detection.py +656 -0
  92. matrice_analytics/post_processing/usecases/assembly_line_detection.py +841 -0
  93. matrice_analytics/post_processing/usecases/banana_defect_detection.py +624 -0
  94. matrice_analytics/post_processing/usecases/basic_counting_tracking.py +667 -0
  95. matrice_analytics/post_processing/usecases/blood_cancer_detection_img.py +881 -0
  96. matrice_analytics/post_processing/usecases/car_damage_detection.py +834 -0
  97. matrice_analytics/post_processing/usecases/car_part_segmentation.py +946 -0
  98. matrice_analytics/post_processing/usecases/car_service.py +1601 -0
  99. matrice_analytics/post_processing/usecases/cardiomegaly_classification.py +864 -0
  100. matrice_analytics/post_processing/usecases/cell_microscopy_segmentation.py +897 -0
  101. matrice_analytics/post_processing/usecases/chicken_pose_detection.py +648 -0
  102. matrice_analytics/post_processing/usecases/child_monitoring.py +814 -0
  103. matrice_analytics/post_processing/usecases/color/clip.py +660 -0
  104. matrice_analytics/post_processing/usecases/color/clip_processor/merges.txt +48895 -0
  105. matrice_analytics/post_processing/usecases/color/clip_processor/preprocessor_config.json +28 -0
  106. matrice_analytics/post_processing/usecases/color/clip_processor/special_tokens_map.json +30 -0
  107. matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer.json +245079 -0
  108. matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer_config.json +32 -0
  109. matrice_analytics/post_processing/usecases/color/clip_processor/vocab.json +1 -0
  110. matrice_analytics/post_processing/usecases/color/color_map_utils.py +70 -0
  111. matrice_analytics/post_processing/usecases/color/color_mapper.py +468 -0
  112. matrice_analytics/post_processing/usecases/color_detection.py +1936 -0
  113. matrice_analytics/post_processing/usecases/color_map_utils.py +70 -0
  114. matrice_analytics/post_processing/usecases/concrete_crack_detection.py +827 -0
  115. matrice_analytics/post_processing/usecases/crop_weed_detection.py +781 -0
  116. matrice_analytics/post_processing/usecases/customer_service.py +1008 -0
  117. matrice_analytics/post_processing/usecases/defect_detection_products.py +936 -0
  118. matrice_analytics/post_processing/usecases/distracted_driver_detection.py +822 -0
  119. matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +585 -0
  120. matrice_analytics/post_processing/usecases/drowsy_driver_detection.py +829 -0
  121. matrice_analytics/post_processing/usecases/dwell_detection.py +829 -0
  122. matrice_analytics/post_processing/usecases/emergency_vehicle_detection.py +827 -0
  123. matrice_analytics/post_processing/usecases/face_emotion.py +813 -0
  124. matrice_analytics/post_processing/usecases/face_recognition.py +827 -0
  125. matrice_analytics/post_processing/usecases/fashion_detection.py +835 -0
  126. matrice_analytics/post_processing/usecases/field_mapping.py +902 -0
  127. matrice_analytics/post_processing/usecases/fire_detection.py +1146 -0
  128. matrice_analytics/post_processing/usecases/flare_analysis.py +836 -0
  129. matrice_analytics/post_processing/usecases/flower_segmentation.py +1006 -0
  130. matrice_analytics/post_processing/usecases/gas_leak_detection.py +837 -0
  131. matrice_analytics/post_processing/usecases/gender_detection.py +832 -0
  132. matrice_analytics/post_processing/usecases/human_activity_recognition.py +871 -0
  133. matrice_analytics/post_processing/usecases/intrusion_detection.py +1672 -0
  134. matrice_analytics/post_processing/usecases/leaf.py +821 -0
  135. matrice_analytics/post_processing/usecases/leaf_disease.py +840 -0
  136. matrice_analytics/post_processing/usecases/leak_detection.py +837 -0
  137. matrice_analytics/post_processing/usecases/license_plate_detection.py +1188 -0
  138. matrice_analytics/post_processing/usecases/license_plate_monitoring.py +1781 -0
  139. matrice_analytics/post_processing/usecases/litter_monitoring.py +717 -0
  140. matrice_analytics/post_processing/usecases/mask_detection.py +869 -0
  141. matrice_analytics/post_processing/usecases/natural_disaster.py +907 -0
  142. matrice_analytics/post_processing/usecases/parking.py +787 -0
  143. matrice_analytics/post_processing/usecases/parking_space_detection.py +822 -0
  144. matrice_analytics/post_processing/usecases/pcb_defect_detection.py +888 -0
  145. matrice_analytics/post_processing/usecases/pedestrian_detection.py +808 -0
  146. matrice_analytics/post_processing/usecases/people_counting.py +706 -0
  147. matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
  148. matrice_analytics/post_processing/usecases/people_tracking.py +1842 -0
  149. matrice_analytics/post_processing/usecases/pipeline_detection.py +605 -0
  150. matrice_analytics/post_processing/usecases/plaque_segmentation_img.py +874 -0
  151. matrice_analytics/post_processing/usecases/pothole_segmentation.py +915 -0
  152. matrice_analytics/post_processing/usecases/ppe_compliance.py +645 -0
  153. matrice_analytics/post_processing/usecases/price_tag_detection.py +822 -0
  154. matrice_analytics/post_processing/usecases/proximity_detection.py +1901 -0
  155. matrice_analytics/post_processing/usecases/road_lane_detection.py +623 -0
  156. matrice_analytics/post_processing/usecases/road_traffic_density.py +832 -0
  157. matrice_analytics/post_processing/usecases/road_view_segmentation.py +915 -0
  158. matrice_analytics/post_processing/usecases/shelf_inventory_detection.py +583 -0
  159. matrice_analytics/post_processing/usecases/shoplifting_detection.py +822 -0
  160. matrice_analytics/post_processing/usecases/shopping_cart_analysis.py +899 -0
  161. matrice_analytics/post_processing/usecases/skin_cancer_classification_img.py +864 -0
  162. matrice_analytics/post_processing/usecases/smoker_detection.py +833 -0
  163. matrice_analytics/post_processing/usecases/solar_panel.py +810 -0
  164. matrice_analytics/post_processing/usecases/suspicious_activity_detection.py +1030 -0
  165. matrice_analytics/post_processing/usecases/template_usecase.py +380 -0
  166. matrice_analytics/post_processing/usecases/theft_detection.py +648 -0
  167. matrice_analytics/post_processing/usecases/traffic_sign_monitoring.py +724 -0
  168. matrice_analytics/post_processing/usecases/underground_pipeline_defect_detection.py +775 -0
  169. matrice_analytics/post_processing/usecases/underwater_pollution_detection.py +842 -0
  170. matrice_analytics/post_processing/usecases/vehicle_monitoring.py +1029 -0
  171. matrice_analytics/post_processing/usecases/warehouse_object_segmentation.py +899 -0
  172. matrice_analytics/post_processing/usecases/waterbody_segmentation.py +923 -0
  173. matrice_analytics/post_processing/usecases/weapon_detection.py +771 -0
  174. matrice_analytics/post_processing/usecases/weld_defect_detection.py +615 -0
  175. matrice_analytics/post_processing/usecases/wildlife_monitoring.py +898 -0
  176. matrice_analytics/post_processing/usecases/windmill_maintenance.py +834 -0
  177. matrice_analytics/post_processing/usecases/wound_segmentation.py +856 -0
  178. matrice_analytics/post_processing/utils/__init__.py +150 -0
  179. matrice_analytics/post_processing/utils/advanced_counting_utils.py +400 -0
  180. matrice_analytics/post_processing/utils/advanced_helper_utils.py +317 -0
  181. matrice_analytics/post_processing/utils/advanced_tracking_utils.py +461 -0
  182. matrice_analytics/post_processing/utils/alerting_utils.py +213 -0
  183. matrice_analytics/post_processing/utils/category_mapping_utils.py +94 -0
  184. matrice_analytics/post_processing/utils/color_utils.py +592 -0
  185. matrice_analytics/post_processing/utils/counting_utils.py +182 -0
  186. matrice_analytics/post_processing/utils/filter_utils.py +261 -0
  187. matrice_analytics/post_processing/utils/format_utils.py +293 -0
  188. matrice_analytics/post_processing/utils/geometry_utils.py +300 -0
  189. matrice_analytics/post_processing/utils/smoothing_utils.py +358 -0
  190. matrice_analytics/post_processing/utils/tracking_utils.py +234 -0
  191. matrice_analytics/py.typed +0 -0
  192. matrice_analytics-0.1.60.dist-info/METADATA +481 -0
  193. matrice_analytics-0.1.60.dist-info/RECORD +196 -0
  194. matrice_analytics-0.1.60.dist-info/WHEEL +5 -0
  195. matrice_analytics-0.1.60.dist-info/licenses/LICENSE.txt +21 -0
  196. matrice_analytics-0.1.60.dist-info/top_level.txt +1 -0
@@ -0,0 +1,99 @@
1
+ """
2
+ Base tracking classes for advanced tracker implementation.
3
+
4
+ This module provides the foundational classes for object tracking,
5
+ including track states and base track functionality.
6
+ """
7
+
8
+ from collections import OrderedDict
9
+ from typing import Any
10
+ import numpy as np
11
+
12
+
13
class TrackState:
    """
    Integer constants naming the lifecycle states of a tracked object.

    Attributes:
        New (int): 0 - the object has just been detected.
        Tracked (int): 1 - the object is being followed across subsequent frames.
        Lost (int): 2 - the object is no longer tracked.
        Removed (int): 3 - the object has been removed from tracking.
    """

    # Unpacking range(4) assigns 0, 1, 2, 3 in declaration order.
    New, Tracked, Lost, Removed = range(4)
28
+
29
+
30
class BaseTrack:
    """
    Common base for single-object tracks.

    Holds a class-wide counter used to hand out track IDs and the per-track bookkeeping
    fields shared by all track types. ``activate``, ``predict`` and ``update`` are left to
    subclasses and raise ``NotImplementedError`` here.

    Attributes:
        _count (int): Class-level counter incremented by ``next_id``.
        track_id (int): Identifier of this track; 0 until assigned.
        is_activated (bool): Whether the track has been activated (initially False).
        state (int): One of the ``TrackState`` constants; starts as ``TrackState.New``.
        history (OrderedDict): Ordered history of the track's states (initially empty).
        features (list): Features extracted from the object (initially empty).
        curr_feature (Any): Current feature of the object, or None.
        score (float): Confidence score of the track (initially 0).
        start_frame (int): Frame number where tracking started (initially 0).
        frame_id (int): Most recent frame ID processed by the track (initially 0).
        time_since_update (int): Frames elapsed since the last update (initially 0).
        location (tuple): Location used for multi-camera tracking; ``(inf, inf)`` by default.
    """

    _count = 0

    def __init__(self):
        """Populate the per-track fields with their starting values."""
        # Identity and lifecycle.
        self.track_id, self.is_activated, self.state = 0, False, TrackState.New
        # Observation / appearance bookkeeping.
        self.history, self.features, self.curr_feature = OrderedDict(), [], None
        # Scoring and frame counters.
        self.score, self.start_frame, self.frame_id, self.time_since_update = 0, 0, 0, 0
        # No location known yet.
        self.location = (np.inf, np.inf)

    @property
    def end_frame(self) -> int:
        """ID of the last frame in which this track was updated (alias of ``frame_id``)."""
        return self.frame_id

    @staticmethod
    def next_id() -> int:
        """Advance the shared counter and return the new, globally unique track ID."""
        issued = BaseTrack._count + 1
        BaseTrack._count = issued
        return issued

    def activate(self, *args: Any) -> None:
        """Start the track; must be provided by a subclass."""
        raise NotImplementedError

    def predict(self) -> None:
        """Propagate the track state one step forward; must be provided by a subclass."""
        raise NotImplementedError

    def update(self, *args: Any, **kwargs: Any) -> None:
        """Incorporate a new observation into the track; must be provided by a subclass."""
        raise NotImplementedError

    def mark_lost(self) -> None:
        """Move the track into the ``TrackState.Lost`` state."""
        self.state = TrackState.Lost

    def mark_removed(self) -> None:
        """Move the track into the ``TrackState.Removed`` state."""
        self.state = TrackState.Removed

    @staticmethod
    def reset_id() -> None:
        """Rewind the shared track-ID counter back to zero."""
        BaseTrack._count = 0
@@ -0,0 +1,77 @@
1
+ """
2
+ Configuration classes for advanced tracker.
3
+
4
+ This module provides configuration classes for the advanced tracker,
5
+ including parameters for tracking algorithms and thresholds.
6
+ """
7
+
8
+ from dataclasses import dataclass, field
9
+ from typing import Optional
10
+
11
+
12
@dataclass
class TrackerConfig:
    """
    Configuration for advanced tracker.

    This class contains all the parameters needed to configure the tracking algorithm,
    including thresholds, buffer sizes, and algorithm-specific settings. Values are
    validated in ``__post_init__``; invalid values raise ``ValueError``.

    ``max_time_lost`` may be left as ``None`` (the default). In that case it is derived
    from ``frame_rate`` and ``track_buffer`` as ``int(frame_rate / 30.0 * track_buffer)``,
    which gives 600 with the default ``frame_rate=30`` and ``track_buffer=600``.
    An explicitly supplied value is kept unchanged.
    """

    # Tracking thresholds (validated to lie in [0.0, 1.0])
    track_high_thresh: float = 0.7
    track_low_thresh: float = 0.1
    new_track_thresh: float = 0.7
    match_thresh: float = 0.8

    # Buffer settings
    track_buffer: int = 600
    # None -> derived from frame_rate and track_buffer in __post_init__.
    max_time_lost: Optional[int] = None

    # Algorithm settings
    fuse_score: bool = True
    enable_gmc: bool = True
    gmc_method: str = "sparseOptFlow"  # "orb", "sift", "ecc", "sparseOptFlow", "none"
    gmc_downscale: int = 2

    # Frame rate (used for max_time_lost calculation)
    frame_rate: int = 30

    # Output format settings
    output_format: str = "tracking"  # "tracking" or "detection"

    # Additional settings
    enable_smoothing: bool = True
    smoothing_algorithm: str = "observability"  # "window" or "observability"
    smoothing_window_size: int = 20
    smoothing_cooldown_frames: int = 5

    def __post_init__(self):
        """
        Validate configuration parameters and derive ``max_time_lost`` when unset.

        Raises:
            ValueError: If a threshold is outside [0.0, 1.0], ``track_buffer`` or
                ``frame_rate`` is not positive, or ``gmc_method`` / ``output_format``
                is not one of the supported values.
        """
        for name in ("track_high_thresh", "track_low_thresh", "new_track_thresh", "match_thresh"):
            value = getattr(self, name)
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be between 0.0 and 1.0, got {value}")

        if self.track_buffer <= 0:
            raise ValueError(f"track_buffer must be positive, got {self.track_buffer}")

        if self.frame_rate <= 0:
            raise ValueError(f"frame_rate must be positive, got {self.frame_rate}")

        if self.gmc_method not in ("orb", "sift", "ecc", "sparseOptFlow", "none"):
            raise ValueError(f"Invalid gmc_method: {self.gmc_method}")

        if self.output_format not in ("tracking", "detection"):
            raise ValueError(f"Invalid output_format: {self.output_format}")

        # Derive max_time_lost only when the caller did not supply one. The previous check
        # compared against 30 while the default was 600, so the derivation never ran for
        # defaults and an explicit value of 30 was silently overwritten.
        if self.max_time_lost is None:
            self.max_time_lost = int(self.frame_rate / 30.0 * self.track_buffer)
@@ -0,0 +1,370 @@
1
+ """
2
+ Kalman filter implementation for advanced tracker.
3
+
4
+ This module provides Kalman filter implementations for tracking bounding boxes,
5
+ including both XYAH and XYWH formats.
6
+ """
7
+
8
+ import numpy as np
9
+ import scipy.linalg
10
+
11
+
12
class KalmanFilterXYAH:
    """
    A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.

    Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space
    (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their
    respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is
    taken as a direct observation of the state space (linear observation model).

    Noise standard deviations are not fixed. Most terms scale with the box height (element 3 of the mean or
    measurement). The aspect-ratio terms instead use small constants.
    """

    def __init__(self):
        """
        Initialize Kalman filter model matrices with motion and observation uncertainty weights.

        The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y)
        represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective
        velocities are (vx, vy, va, vh). The filter uses a constant velocity model for object motion and a linear
        observation model for bounding box location.
        """
        ndim, dt = 4, 1.0

        # Constant-velocity transition: each position component gains dt * its velocity per step.
        self._motion_mat = np.eye(2 * ndim, 2 * ndim)
        for i in range(ndim):
            self._motion_mat[i, ndim + i] = dt
        # Observation matrix: selects the first ndim state entries (x, y, a, h).
        self._update_mat = np.eye(ndim, 2 * ndim)

        # Motion and observation uncertainty are chosen relative to the current state estimate
        self._std_weight_position = 1.0 / 20
        self._std_weight_velocity = 1.0 / 160

    def initiate(self, measurement: np.ndarray):
        """
        Create a track from an unassociated measurement.

        Args:
            measurement (np.ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
                and height h.

        Returns:
            mean (np.ndarray): Mean vector (8-dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
            covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
        """
        mean_pos = measurement
        mean_vel = np.zeros_like(mean_pos)
        mean = np.r_[mean_pos, mean_vel]

        # Initial uncertainty: wide for position/height, wider still for the unobserved velocities.
        std = [
            2 * self._std_weight_position * measurement[3],
            2 * self._std_weight_position * measurement[3],
            1e-2,
            2 * self._std_weight_position * measurement[3],
            10 * self._std_weight_velocity * measurement[3],
            10 * self._std_weight_velocity * measurement[3],
            1e-5,
            10 * self._std_weight_velocity * measurement[3],
        ]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean: np.ndarray, covariance: np.ndarray):
        """
        Run Kalman filter prediction step.

        Args:
            mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.
            covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.

        Returns:
            mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
            covariance (np.ndarray): Covariance matrix of the predicted state.
        """
        std_pos = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-2,
            self._std_weight_position * mean[3],
        ]
        std_vel = [
            self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[3],
            1e-5,
            self._std_weight_velocity * mean[3],
        ]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        # x' = F x ;  P' = F P F^T + Q
        mean = np.dot(mean, self._motion_mat.T)
        covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov

        return mean, covariance

    def project(self, mean: np.ndarray, covariance: np.ndarray):
        """
        Project state distribution to measurement space.

        Args:
            mean (np.ndarray): The state's mean vector (8 dimensional array).
            covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).

        Returns:
            mean (np.ndarray): Projected mean of the given state estimate.
            covariance (np.ndarray): Projected covariance matrix of the given state estimate,
                including the measurement (innovation) noise.
        """
        std = [
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[3],
            1e-1,
            self._std_weight_position * mean[3],
        ]
        innovation_cov = np.diag(np.square(std))

        # z = H x ;  S = H P H^T + R
        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
        """
        Run Kalman filter prediction step for multiple object states (Vectorized version).

        Args:
            mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
            covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.

        Returns:
            mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).
            covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
                An empty batch (N == 0) returns arrays of shape (0, 8) and (0, 8, 8).
        """
        std_pos = [
            self._std_weight_position * mean[:, 3],
            self._std_weight_position * mean[:, 3],
            1e-2 * np.ones_like(mean[:, 3]),
            self._std_weight_position * mean[:, 3],
        ]
        std_vel = [
            self._std_weight_velocity * mean[:, 3],
            self._std_weight_velocity * mean[:, 3],
            1e-5 * np.ones_like(mean[:, 3]),
            self._std_weight_velocity * mean[:, 3],
        ]
        sqr = np.square(np.r_[std_pos, std_vel]).T  # (N, 8) per-track variances

        # Batch of diagonal matrices with motion_cov[i] == np.diag(sqr[i]), built by broadcasting.
        # Unlike a list of np.diag calls, this keeps shape (0, 8, 8) for an empty batch, so the
        # addition below still broadcasts.
        motion_cov = sqr[:, :, np.newaxis] * np.eye(sqr.shape[1])

        mean = np.dot(mean, self._motion_mat.T)
        # np.dot(F, P) has shape (8, N, 8); transpose to (N, 8, 8) so that left[i] == F @ P[i].
        left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
        covariance = np.dot(left, self._motion_mat.T) + motion_cov

        return mean, covariance

    def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
        """
        Run Kalman filter correction step.

        Args:
            mean (np.ndarray): The predicted state's mean vector (8 dimensional).
            covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
            measurement (np.ndarray): The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center
                position, a the aspect ratio, and h the height of the bounding box.

        Returns:
            new_mean (np.ndarray): Measurement-corrected state mean.
            new_covariance (np.ndarray): Measurement-corrected state covariance.

        Raises:
            numpy.linalg.LinAlgError: Raised by ``scipy.linalg.cho_factor`` if the projected covariance is not
                positive definite.
        """
        projected_mean, projected_cov = self.project(mean, covariance)

        # Solve S K^T = (P H^T)^T via Cholesky instead of inverting S explicitly.
        chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
        kalman_gain = scipy.linalg.cho_solve(
            (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False
        ).T
        innovation = measurement - projected_mean

        new_mean = mean + np.dot(innovation, kalman_gain.T)
        new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
        return new_mean, new_covariance

    def gating_distance(
        self,
        mean: np.ndarray,
        covariance: np.ndarray,
        measurements: np.ndarray,
        only_position: bool = False,
        metric: str = "maha",
    ) -> np.ndarray:
        """
        Compute gating distance between state distribution and measurements.

        Args:
            mean (np.ndarray): Mean vector over the state distribution (8 dimensional).
            covariance (np.ndarray): Covariance of the state distribution (8x8 dimensional).
            measurements (np.ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the
                bounding box center position, a the aspect ratio, and h the height.
            only_position (bool, optional): If True, distance computation is done with respect to box center position only.
            metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the squared
                Euclidean distance and 'maha' for the squared Mahalanobis distance.

        Returns:
            (np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
                (mean, covariance) and `measurements[i]`.

        Raises:
            ValueError: If ``metric`` is neither 'gaussian' nor 'maha'.
        """
        mean, covariance = self.project(mean, covariance)
        if only_position:
            mean, covariance = mean[:2], covariance[:2, :2]
            measurements = measurements[:, :2]

        d = measurements - mean
        if metric == "gaussian":
            return np.sum(d * d, axis=1)
        elif metric == "maha":
            # With S = L L^T, the squared Mahalanobis distance is ||L^{-1} d||^2.
            cholesky_factor = np.linalg.cholesky(covariance)
            z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
            return np.sum(z * z, axis=0)  # square maha
        else:
            raise ValueError("Invalid distance metric")
226
+
227
+
228
class KalmanFilterXYWH(KalmanFilterXYAH):
    """
    A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.

    Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where
    (x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities.
    The object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct
    observation of the state space (linear observation model).

    ``_motion_mat``, ``_update_mat``, ``_std_weight_position`` and ``_std_weight_velocity`` are not defined in this
    class. They are presumably initialised by ``KalmanFilterXYAH`` (confirm against the parent constructor).
    In every method here the noise standard deviations scale with the box width (index 2) and height (index 3).
    """

    def initiate(self, measurement: np.ndarray):
        """
        Create track from unassociated measurement.

        Args:
            measurement (np.ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height.

        Returns:
            mean (np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
            covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
        """
        mean_pos = measurement
        mean_vel = np.zeros_like(mean_pos)
        mean = np.r_[mean_pos, mean_vel]

        # Initial uncertainty is proportional to box size: positions get twice the position weight,
        # and the (unobserved) velocities get ten times the velocity weight.
        std = [
            2 * self._std_weight_position * measurement[2],
            2 * self._std_weight_position * measurement[3],
            2 * self._std_weight_position * measurement[2],
            2 * self._std_weight_position * measurement[3],
            10 * self._std_weight_velocity * measurement[2],
            10 * self._std_weight_velocity * measurement[3],
            10 * self._std_weight_velocity * measurement[2],
            10 * self._std_weight_velocity * measurement[3],
        ]
        covariance = np.diag(np.square(std))
        return mean, covariance

    def predict(self, mean: np.ndarray, covariance: np.ndarray):
        """
        Run Kalman filter prediction step.

        Args:
            mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.
            covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.

        Returns:
            mean (np.ndarray): Mean vector of the predicted state (previous state propagated by the motion model).
            covariance (np.ndarray): Covariance matrix of the predicted state (propagated covariance plus process noise).
        """
        # Process noise scales with the current width/height estimate.
        std_pos = [
            self._std_weight_position * mean[2],
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[2],
            self._std_weight_position * mean[3],
        ]
        std_vel = [
            self._std_weight_velocity * mean[2],
            self._std_weight_velocity * mean[3],
            self._std_weight_velocity * mean[2],
            self._std_weight_velocity * mean[3],
        ]
        motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))

        mean = np.dot(mean, self._motion_mat.T)
        covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov

        return mean, covariance

    def project(self, mean: np.ndarray, covariance: np.ndarray):
        """
        Project state distribution to measurement space.

        Args:
            mean (np.ndarray): The state's mean vector (8 dimensional array).
            covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).

        Returns:
            mean (np.ndarray): Projected mean of the given state estimate.
            covariance (np.ndarray): Projected covariance matrix of the given state estimate, including the
                measurement-noise term.
        """
        # Measurement noise, again proportional to the estimated width/height.
        std = [
            self._std_weight_position * mean[2],
            self._std_weight_position * mean[3],
            self._std_weight_position * mean[2],
            self._std_weight_position * mean[3],
        ]
        innovation_cov = np.diag(np.square(std))

        mean = np.dot(self._update_mat, mean)
        covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
        return mean, covariance + innovation_cov

    def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
        """
        Run Kalman filter prediction step (Vectorized version).

        Args:
            mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
            covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.

        Returns:
            mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).
            covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
            For an empty batch (N == 0), arrays of shape (0, 8) and (0, 8, 8) are returned.
        """
        if len(mean) == 0:
            # Nothing to propagate. Without this guard an empty batch fails: either at mean[:, 2]
            # (e.g. mean built from an empty list) or when adding a (0,)-shaped noise array to a
            # (0, 8, 8) covariance.
            state_dim = self._motion_mat.shape[0]
            return np.zeros((0, state_dim)), np.zeros((0, state_dim, state_dim))

        std_pos = [
            self._std_weight_position * mean[:, 2],
            self._std_weight_position * mean[:, 3],
            self._std_weight_position * mean[:, 2],
            self._std_weight_position * mean[:, 3],
        ]
        std_vel = [
            self._std_weight_velocity * mean[:, 2],
            self._std_weight_velocity * mean[:, 3],
            self._std_weight_velocity * mean[:, 2],
            self._std_weight_velocity * mean[:, 3],
        ]
        sqr = np.square(np.r_[std_pos, std_vel]).T  # (N, 8): per-track noise variances

        # Stack of diagonal process-noise matrices, filled in one vectorised assignment:
        # motion_cov[n, i, i] = sqr[n, i], all off-diagonal entries stay 0.
        num_tracks, state_dim = sqr.shape
        motion_cov = np.zeros((num_tracks, state_dim, state_dim), dtype=sqr.dtype)
        diag_idx = np.arange(state_dim)
        motion_cov[:, diag_idx, diag_idx] = sqr

        mean = np.dot(mean, self._motion_mat.T)
        # np.dot(F, C) with C of shape (N, 8, 8) yields (8, N, 8); transposing gives F @ C[n] per track.
        left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
        covariance = np.dot(left, self._motion_mat.T) + motion_cov

        return mean, covariance

    def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
        """
        Run Kalman filter correction step.

        Delegates to ``KalmanFilterXYAH.update``. This override exists to document the (x, y, w, h) measurement
        layout; presumably the parent relies on this class's ``project`` via dynamic dispatch (confirm).

        Args:
            mean (np.ndarray): The predicted state's mean vector (8 dimensional).
            covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
            measurement (np.ndarray): The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center
                position, w the width, and h the height of the bounding box.

        Returns:
            new_mean (np.ndarray): Measurement-corrected state mean.
            new_covariance (np.ndarray): Measurement-corrected state covariance.
        """
        return super().update(mean, covariance, measurement)