matrice-analytics 0.1.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. matrice_analytics/__init__.py +28 -0
  2. matrice_analytics/boundary_drawing_internal/README.md +305 -0
  3. matrice_analytics/boundary_drawing_internal/__init__.py +45 -0
  4. matrice_analytics/boundary_drawing_internal/boundary_drawing_internal.py +1207 -0
  5. matrice_analytics/boundary_drawing_internal/boundary_drawing_tool.py +429 -0
  6. matrice_analytics/boundary_drawing_internal/boundary_tool_template.html +1036 -0
  7. matrice_analytics/boundary_drawing_internal/data/.gitignore +12 -0
  8. matrice_analytics/boundary_drawing_internal/example_usage.py +206 -0
  9. matrice_analytics/boundary_drawing_internal/usage/README.md +110 -0
  10. matrice_analytics/boundary_drawing_internal/usage/boundary_drawer_launcher.py +102 -0
  11. matrice_analytics/boundary_drawing_internal/usage/simple_boundary_launcher.py +107 -0
  12. matrice_analytics/post_processing/README.md +455 -0
  13. matrice_analytics/post_processing/__init__.py +732 -0
  14. matrice_analytics/post_processing/advanced_tracker/README.md +650 -0
  15. matrice_analytics/post_processing/advanced_tracker/__init__.py +17 -0
  16. matrice_analytics/post_processing/advanced_tracker/base.py +99 -0
  17. matrice_analytics/post_processing/advanced_tracker/config.py +77 -0
  18. matrice_analytics/post_processing/advanced_tracker/kalman_filter.py +370 -0
  19. matrice_analytics/post_processing/advanced_tracker/matching.py +195 -0
  20. matrice_analytics/post_processing/advanced_tracker/strack.py +230 -0
  21. matrice_analytics/post_processing/advanced_tracker/tracker.py +367 -0
  22. matrice_analytics/post_processing/config.py +146 -0
  23. matrice_analytics/post_processing/core/__init__.py +63 -0
  24. matrice_analytics/post_processing/core/base.py +704 -0
  25. matrice_analytics/post_processing/core/config.py +3291 -0
  26. matrice_analytics/post_processing/core/config_utils.py +925 -0
  27. matrice_analytics/post_processing/face_reg/__init__.py +43 -0
  28. matrice_analytics/post_processing/face_reg/compare_similarity.py +556 -0
  29. matrice_analytics/post_processing/face_reg/embedding_manager.py +950 -0
  30. matrice_analytics/post_processing/face_reg/face_recognition.py +2234 -0
  31. matrice_analytics/post_processing/face_reg/face_recognition_client.py +606 -0
  32. matrice_analytics/post_processing/face_reg/people_activity_logging.py +321 -0
  33. matrice_analytics/post_processing/ocr/__init__.py +0 -0
  34. matrice_analytics/post_processing/ocr/easyocr_extractor.py +250 -0
  35. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
  36. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
  37. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
  38. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
  39. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
  40. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
  41. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
  42. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
  43. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
  44. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
  45. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
  46. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
  47. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
  48. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
  49. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
  50. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
  51. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
  52. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
  53. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
  54. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
  55. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
  56. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
  57. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
  58. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
  59. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
  60. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
  61. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
  62. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
  63. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
  64. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
  65. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
  66. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
  67. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
  68. matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
  69. matrice_analytics/post_processing/ocr/postprocessing.py +270 -0
  70. matrice_analytics/post_processing/ocr/preprocessing.py +52 -0
  71. matrice_analytics/post_processing/post_processor.py +1175 -0
  72. matrice_analytics/post_processing/test_cases/__init__.py +1 -0
  73. matrice_analytics/post_processing/test_cases/run_tests.py +143 -0
  74. matrice_analytics/post_processing/test_cases/test_advanced_customer_service.py +841 -0
  75. matrice_analytics/post_processing/test_cases/test_basic_counting_tracking.py +523 -0
  76. matrice_analytics/post_processing/test_cases/test_comprehensive.py +531 -0
  77. matrice_analytics/post_processing/test_cases/test_config.py +852 -0
  78. matrice_analytics/post_processing/test_cases/test_customer_service.py +585 -0
  79. matrice_analytics/post_processing/test_cases/test_data_generators.py +583 -0
  80. matrice_analytics/post_processing/test_cases/test_people_counting.py +510 -0
  81. matrice_analytics/post_processing/test_cases/test_processor.py +524 -0
  82. matrice_analytics/post_processing/test_cases/test_usecases.py +165 -0
  83. matrice_analytics/post_processing/test_cases/test_utilities.py +356 -0
  84. matrice_analytics/post_processing/test_cases/test_utils.py +743 -0
  85. matrice_analytics/post_processing/usecases/Histopathological_Cancer_Detection_img.py +604 -0
  86. matrice_analytics/post_processing/usecases/__init__.py +267 -0
  87. matrice_analytics/post_processing/usecases/abandoned_object_detection.py +797 -0
  88. matrice_analytics/post_processing/usecases/advanced_customer_service.py +1601 -0
  89. matrice_analytics/post_processing/usecases/age_detection.py +842 -0
  90. matrice_analytics/post_processing/usecases/age_gender_detection.py +1085 -0
  91. matrice_analytics/post_processing/usecases/anti_spoofing_detection.py +656 -0
  92. matrice_analytics/post_processing/usecases/assembly_line_detection.py +841 -0
  93. matrice_analytics/post_processing/usecases/banana_defect_detection.py +624 -0
  94. matrice_analytics/post_processing/usecases/basic_counting_tracking.py +667 -0
  95. matrice_analytics/post_processing/usecases/blood_cancer_detection_img.py +881 -0
  96. matrice_analytics/post_processing/usecases/car_damage_detection.py +834 -0
  97. matrice_analytics/post_processing/usecases/car_part_segmentation.py +946 -0
  98. matrice_analytics/post_processing/usecases/car_service.py +1601 -0
  99. matrice_analytics/post_processing/usecases/cardiomegaly_classification.py +864 -0
  100. matrice_analytics/post_processing/usecases/cell_microscopy_segmentation.py +897 -0
  101. matrice_analytics/post_processing/usecases/chicken_pose_detection.py +648 -0
  102. matrice_analytics/post_processing/usecases/child_monitoring.py +814 -0
  103. matrice_analytics/post_processing/usecases/color/clip.py +660 -0
  104. matrice_analytics/post_processing/usecases/color/clip_processor/merges.txt +48895 -0
  105. matrice_analytics/post_processing/usecases/color/clip_processor/preprocessor_config.json +28 -0
  106. matrice_analytics/post_processing/usecases/color/clip_processor/special_tokens_map.json +30 -0
  107. matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer.json +245079 -0
  108. matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer_config.json +32 -0
  109. matrice_analytics/post_processing/usecases/color/clip_processor/vocab.json +1 -0
  110. matrice_analytics/post_processing/usecases/color/color_map_utils.py +70 -0
  111. matrice_analytics/post_processing/usecases/color/color_mapper.py +468 -0
  112. matrice_analytics/post_processing/usecases/color_detection.py +1936 -0
  113. matrice_analytics/post_processing/usecases/color_map_utils.py +70 -0
  114. matrice_analytics/post_processing/usecases/concrete_crack_detection.py +827 -0
  115. matrice_analytics/post_processing/usecases/crop_weed_detection.py +781 -0
  116. matrice_analytics/post_processing/usecases/customer_service.py +1008 -0
  117. matrice_analytics/post_processing/usecases/defect_detection_products.py +936 -0
  118. matrice_analytics/post_processing/usecases/distracted_driver_detection.py +822 -0
  119. matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +585 -0
  120. matrice_analytics/post_processing/usecases/drowsy_driver_detection.py +829 -0
  121. matrice_analytics/post_processing/usecases/dwell_detection.py +829 -0
  122. matrice_analytics/post_processing/usecases/emergency_vehicle_detection.py +827 -0
  123. matrice_analytics/post_processing/usecases/face_emotion.py +813 -0
  124. matrice_analytics/post_processing/usecases/face_recognition.py +827 -0
  125. matrice_analytics/post_processing/usecases/fashion_detection.py +835 -0
  126. matrice_analytics/post_processing/usecases/field_mapping.py +902 -0
  127. matrice_analytics/post_processing/usecases/fire_detection.py +1146 -0
  128. matrice_analytics/post_processing/usecases/flare_analysis.py +836 -0
  129. matrice_analytics/post_processing/usecases/flower_segmentation.py +1006 -0
  130. matrice_analytics/post_processing/usecases/gas_leak_detection.py +837 -0
  131. matrice_analytics/post_processing/usecases/gender_detection.py +832 -0
  132. matrice_analytics/post_processing/usecases/human_activity_recognition.py +871 -0
  133. matrice_analytics/post_processing/usecases/intrusion_detection.py +1672 -0
  134. matrice_analytics/post_processing/usecases/leaf.py +821 -0
  135. matrice_analytics/post_processing/usecases/leaf_disease.py +840 -0
  136. matrice_analytics/post_processing/usecases/leak_detection.py +837 -0
  137. matrice_analytics/post_processing/usecases/license_plate_detection.py +1188 -0
  138. matrice_analytics/post_processing/usecases/license_plate_monitoring.py +1781 -0
  139. matrice_analytics/post_processing/usecases/litter_monitoring.py +717 -0
  140. matrice_analytics/post_processing/usecases/mask_detection.py +869 -0
  141. matrice_analytics/post_processing/usecases/natural_disaster.py +907 -0
  142. matrice_analytics/post_processing/usecases/parking.py +787 -0
  143. matrice_analytics/post_processing/usecases/parking_space_detection.py +822 -0
  144. matrice_analytics/post_processing/usecases/pcb_defect_detection.py +888 -0
  145. matrice_analytics/post_processing/usecases/pedestrian_detection.py +808 -0
  146. matrice_analytics/post_processing/usecases/people_counting.py +706 -0
  147. matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
  148. matrice_analytics/post_processing/usecases/people_tracking.py +1842 -0
  149. matrice_analytics/post_processing/usecases/pipeline_detection.py +605 -0
  150. matrice_analytics/post_processing/usecases/plaque_segmentation_img.py +874 -0
  151. matrice_analytics/post_processing/usecases/pothole_segmentation.py +915 -0
  152. matrice_analytics/post_processing/usecases/ppe_compliance.py +645 -0
  153. matrice_analytics/post_processing/usecases/price_tag_detection.py +822 -0
  154. matrice_analytics/post_processing/usecases/proximity_detection.py +1901 -0
  155. matrice_analytics/post_processing/usecases/road_lane_detection.py +623 -0
  156. matrice_analytics/post_processing/usecases/road_traffic_density.py +832 -0
  157. matrice_analytics/post_processing/usecases/road_view_segmentation.py +915 -0
  158. matrice_analytics/post_processing/usecases/shelf_inventory_detection.py +583 -0
  159. matrice_analytics/post_processing/usecases/shoplifting_detection.py +822 -0
  160. matrice_analytics/post_processing/usecases/shopping_cart_analysis.py +899 -0
  161. matrice_analytics/post_processing/usecases/skin_cancer_classification_img.py +864 -0
  162. matrice_analytics/post_processing/usecases/smoker_detection.py +833 -0
  163. matrice_analytics/post_processing/usecases/solar_panel.py +810 -0
  164. matrice_analytics/post_processing/usecases/suspicious_activity_detection.py +1030 -0
  165. matrice_analytics/post_processing/usecases/template_usecase.py +380 -0
  166. matrice_analytics/post_processing/usecases/theft_detection.py +648 -0
  167. matrice_analytics/post_processing/usecases/traffic_sign_monitoring.py +724 -0
  168. matrice_analytics/post_processing/usecases/underground_pipeline_defect_detection.py +775 -0
  169. matrice_analytics/post_processing/usecases/underwater_pollution_detection.py +842 -0
  170. matrice_analytics/post_processing/usecases/vehicle_monitoring.py +1029 -0
  171. matrice_analytics/post_processing/usecases/warehouse_object_segmentation.py +899 -0
  172. matrice_analytics/post_processing/usecases/waterbody_segmentation.py +923 -0
  173. matrice_analytics/post_processing/usecases/weapon_detection.py +771 -0
  174. matrice_analytics/post_processing/usecases/weld_defect_detection.py +615 -0
  175. matrice_analytics/post_processing/usecases/wildlife_monitoring.py +898 -0
  176. matrice_analytics/post_processing/usecases/windmill_maintenance.py +834 -0
  177. matrice_analytics/post_processing/usecases/wound_segmentation.py +856 -0
  178. matrice_analytics/post_processing/utils/__init__.py +150 -0
  179. matrice_analytics/post_processing/utils/advanced_counting_utils.py +400 -0
  180. matrice_analytics/post_processing/utils/advanced_helper_utils.py +317 -0
  181. matrice_analytics/post_processing/utils/advanced_tracking_utils.py +461 -0
  182. matrice_analytics/post_processing/utils/alerting_utils.py +213 -0
  183. matrice_analytics/post_processing/utils/category_mapping_utils.py +94 -0
  184. matrice_analytics/post_processing/utils/color_utils.py +592 -0
  185. matrice_analytics/post_processing/utils/counting_utils.py +182 -0
  186. matrice_analytics/post_processing/utils/filter_utils.py +261 -0
  187. matrice_analytics/post_processing/utils/format_utils.py +293 -0
  188. matrice_analytics/post_processing/utils/geometry_utils.py +300 -0
  189. matrice_analytics/post_processing/utils/smoothing_utils.py +358 -0
  190. matrice_analytics/post_processing/utils/tracking_utils.py +234 -0
  191. matrice_analytics/py.typed +0 -0
  192. matrice_analytics-0.1.60.dist-info/METADATA +481 -0
  193. matrice_analytics-0.1.60.dist-info/RECORD +196 -0
  194. matrice_analytics-0.1.60.dist-info/WHEEL +5 -0
  195. matrice_analytics-0.1.60.dist-info/licenses/LICENSE.txt +21 -0
  196. matrice_analytics-0.1.60.dist-info/top_level.txt +1 -0
@@ -0,0 +1,195 @@
1
+ """
2
+ Matching utilities for track association.
3
+
4
+ This module provides utilities for matching tracks with detections,
5
+ including IoU distance calculation and linear assignment.
6
+ """
7
+
8
+ import numpy as np
9
+ import scipy
10
+ from scipy.spatial.distance import cdist
11
+
12
+ try:
13
+ import lap # for linear_assignment
14
+ assert lap.__version__ # verify package is not directory
15
+ except (ImportError, AssertionError, AttributeError):
16
+ # Fallback to scipy if lap is not available
17
+ lap = None
18
+
19
+
20
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True):
    """
    Solve the (rectangular) linear assignment problem and split the result into matches and leftovers.

    Uses ``lap.lapjv`` when ``use_lap`` is True and the ``lap`` package is available, otherwise
    ``scipy.optimize.linear_sum_assignment``. Assignments whose cost exceeds ``thresh`` are rejected.

    Args:
        cost_matrix (np.ndarray): Cost of assigning row i to column j, with shape (N, M).
        thresh (float): Maximum cost for an assignment to be accepted.
        use_lap (bool): Prefer lap.lapjv. If False (or lap is unavailable), scipy is used.

    Returns:
        matched_indices (np.ndarray): Integer array of shape (K, 2) holding (row, col) pairs.
        unmatched_a (np.ndarray): Sorted integer indices of rows left unassigned.
        unmatched_b (np.ndarray): Sorted integer indices of columns left unassigned.
    """
    n_rows, n_cols = cost_matrix.shape
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), np.arange(n_rows), np.arange(n_cols)

    if use_lap and lap is not None:
        # lapjv pads the matrix (extend_cost) and applies cost_limit itself;
        # x[i] / y[j] are -1 for unassigned rows / columns.
        _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
        x = np.asarray(x)
        y = np.asarray(y)
        rows = np.flatnonzero(x >= 0)
        matches = np.column_stack((rows, x[rows])).astype(int)
        unmatched_a = np.flatnonzero(x < 0)
        unmatched_b = np.flatnonzero(y < 0)
    else:
        # Import the submodule explicitly: a bare `import scipy` does not expose
        # `scipy.optimize` on older SciPy releases.
        from scipy.optimize import linear_sum_assignment

        rows, cols = linear_sum_assignment(cost_matrix)
        keep = cost_matrix[rows, cols] <= thresh
        # column_stack keeps the (K, 2) shape even when K == 0.
        matches = np.column_stack((rows[keep], cols[keep])).astype(int)
        unmatched_a = np.setdiff1d(np.arange(n_rows), matches[:, 0])
        unmatched_b = np.setdiff1d(np.arange(n_cols), matches[:, 1])

    return matches, unmatched_a, unmatched_b
55
+
56
+
57
def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = True, eps: float = 1e-7) -> np.ndarray:
    """
    Compute pairwise IoU, or intersection over box1's area, for two sets of x1y1x2y2 boxes.

    Args:
        box1 (np.ndarray): First set of boxes, shape (N, 4).
        box2 (np.ndarray): Second set of boxes, shape (M, 4).
        iou (bool): If True return IoU; otherwise return the intersection divided by the
            area of each box in ``box1``.
        eps (float): Small constant added to the denominator so degenerate (zero-area)
            boxes give 0 instead of NaN.

    Returns:
        np.ndarray: Overlap matrix of shape (N, M).
    """
    # Pairwise intersection rectangle by broadcasting (N, 1, 2) against (M, 2) -> (N, M, 2).
    top_left = np.maximum(box1[:, None, :2], box2[:, :2])
    bottom_right = np.minimum(box1[:, None, 2:], box2[:, 2:])
    wh = np.maximum(0, bottom_right - top_left)
    inter = wh[:, :, 0] * wh[:, :, 1]  # (N, M)

    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])  # (N,)
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])  # (M,)

    if iou:
        union = area1[:, None] + area2 - inter
        return inter / (union + eps)
    return inter / (area1[:, None] + eps)
89
+
90
+
91
def _boxes_as_xyxy(items: list) -> list:
    """Return x1y1x2y2 boxes for a list that holds either raw box arrays or track-like objects."""
    if items and isinstance(items[0], np.ndarray):
        # Already raw boxes: pass through unchanged.
        return items
    boxes = []
    for track in items:
        if hasattr(track, 'xyxy'):
            boxes.append(track.xyxy)
        elif hasattr(track, 'tlwh'):
            tlwh = track.tlwh
            boxes.append([tlwh[0], tlwh[1], tlwh[0] + tlwh[2], tlwh[1] + tlwh[3]])
        else:
            # No usable geometry: a zero-area box, which yields zero IoU (maximum cost).
            boxes.append([0, 0, 0, 0])
    return boxes


def iou_distance(atracks: list, btracks: list) -> np.ndarray:
    """
    Compute an IoU-based cost matrix (1 - IoU) between two collections of tracks or boxes.

    Each argument may be a list of x1y1x2y2 ``np.ndarray`` boxes or a list of track objects
    exposing ``xyxy`` or ``tlwh``. The two lists are converted independently, so raw boxes
    can be compared against tracks.

    Args:
        atracks (List[STrack] or List[np.ndarray]): Tracks or boxes for the rows.
        btracks (List[STrack] or List[np.ndarray]): Tracks or boxes for the columns.

    Returns:
        (np.ndarray): Cost matrix of shape (len(atracks), len(btracks)).
    """
    atlbrs = _boxes_as_xyxy(atracks)
    btlbrs = _boxes_as_xyxy(btracks)

    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
    if len(atlbrs) and len(btlbrs):
        ious = bbox_ioa(
            np.ascontiguousarray(atlbrs, dtype=np.float32),
            np.ascontiguousarray(btlbrs, dtype=np.float32),
            iou=True,
        )
    return 1 - ious  # cost matrix
135
+
136
+
137
def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
    """
    Compute distances between track embeddings and detection embeddings.

    Tracks contribute ``smooth_feat`` and detections contribute ``curr_feat``. An object without
    that attribute (or with it set to None) gets a zero vector whose width matches the first
    available embedding (128 if none is available). Metrics such as cosine return NaN for zero
    vectors; those entries become 1.0 so the assignment solvers receive a finite matrix.

    Args:
        tracks (List[STrack] or List[np.ndarray]): Tracks carrying ``smooth_feat`` embeddings.
        detections (List[BaseTrack]): Detections carrying ``curr_feat`` embeddings.
        metric (str): Any metric accepted by ``scipy.spatial.distance.cdist`` (e.g. 'cosine', 'euclidean').

    Returns:
        (np.ndarray): Non-negative cost matrix of shape (N, M), where N is the number of tracks
        and M is the number of detections.
    """
    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
    if cost_matrix.size == 0:
        return cost_matrix

    track_feats = [getattr(track, 'smooth_feat', None) for track in tracks]
    det_feats = [getattr(det, 'curr_feat', None) for det in detections]

    # The placeholder must have the same width as the real embeddings, otherwise
    # np.asarray builds a ragged array and cdist rejects the mismatched columns.
    feat_dim = next((np.shape(f)[-1] for f in track_feats + det_feats if f is not None), 128)
    placeholder = np.zeros(feat_dim, dtype=np.float32)

    track_features = np.asarray([placeholder if f is None else f for f in track_feats], dtype=np.float32)
    det_features = np.asarray([placeholder if f is None else f for f in det_feats], dtype=np.float32)

    distances = cdist(track_features, det_features, metric)
    # Zero vectors make cosine-like metrics return NaN (0/0); treat them as "no similarity".
    distances = np.where(np.isnan(distances), 1.0, distances)
    return np.maximum(0.0, distances)
176
+
177
+
178
def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
    """
    Weight IoU similarity by detection confidence and return it as a cost.

    Computes ``1 - (1 - cost) * score`` per column, where ``score`` is each detection's ``score``.

    Args:
        cost_matrix (np.ndarray): IoU-based cost matrix of shape (N, M).
        detections (List[BaseTrack]): M detections, each with a ``score`` attribute.

    Returns:
        (np.ndarray): Fused cost matrix of shape (N, M). An empty input is returned as-is.
    """
    if cost_matrix.size == 0:
        return cost_matrix
    iou_sim = 1 - cost_matrix
    # A (1, M) row broadcasts across all N rows; there is no need to tile it explicitly.
    det_scores = np.array([det.score for det in detections])[None, :]
    return 1 - iou_sim * det_scores  # fuse_cost
@@ -0,0 +1,230 @@
1
+ """
2
+ STrack class for single object tracking.
3
+
4
+ This module provides the STrack class that represents a single tracked object
5
+ with Kalman filtering for state estimation.
6
+ """
7
+
8
+ from typing import Any, List
9
+ import numpy as np
10
+
11
+ from .base import BaseTrack, TrackState
12
+ from .kalman_filter import KalmanFilterXYAH
13
+
14
+
15
def xywh2ltwh(xywh: List[float]) -> List[float]:
    """
    Re-anchor a box from its center to its top-left corner.

    Args:
        xywh (List[float]): Box as [center_x, center_y, width, height].

    Returns:
        List[float]: Box as [left, top, width, height]; width and height are passed through unchanged.
    """
    cx, cy, width, height = xywh
    left = cx - width / 2
    top = cy - height / 2
    return [left, top, width, height]
27
+
28
+
29
class STrack(BaseTrack):
    """
    Single object tracking representation that uses Kalman filtering for state estimation.

    Stores all the information about an individual tracklet and performs state updates and
    predictions with a Kalman filter.

    Attributes:
        shared_kalman (KalmanFilterXYAH): Kalman filter shared by all instances, used by ``multi_predict``.
        _tlwh (np.ndarray): Initial box as (top-left x, top-left y, width, height), used until the
            filter state exists.
        kalman_filter (KalmanFilterXYAH): Kalman filter assigned to this track in ``activate``.
        mean (np.ndarray): Mean state estimate vector (None until activated).
        covariance (np.ndarray): Covariance of the state estimate (None until activated).
        is_activated (bool): Whether the track has been confirmed.
        score (float): Confidence score of the latest associated detection.
        tracklet_len (int): Number of ``update`` calls since the last (re)activation.
        cls (Any): Class label for the object.
        idx (Any): Detection index/identifier (the last element of the input box).
        frame_id (int): Frame of the latest update.
        start_frame (int): Frame in which the track was activated.
        angle (float or None): Optional angle for oriented bounding boxes.
    """

    shared_kalman = KalmanFilterXYAH()

    def __init__(self, xywh: List[float], score: float, cls: Any):
        """
        Initialize a new STrack instance.

        Args:
            xywh (List[float]): Box values in the format (x, y, w, h, [angle], idx), where (x, y) is
                the center, (w, h) are width and height, ``angle`` is an optional orientation, and
                ``idx`` is the detection id.
            score (float): Confidence score of the detection.
            cls (Any): Class label for the detected object.
        """
        super().__init__()
        # Layout is xywh+idx or xywh+angle+idx.
        assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
        self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        self.score = score
        self.tracklet_len = 0
        self.cls = cls
        self.idx = xywh[-1]
        self.angle = xywh[4] if len(xywh) == 6 else None

    def predict(self):
        """Predict the next state (mean and covariance) of the object using the Kalman filter."""
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            # Presumably zeroes the height velocity (index 7 of the XYAH state) for tracks that
            # are not currently tracked. TODO confirm against KalmanFilterXYAH.
            mean_state[7] = 0
        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

    @staticmethod
    def multi_predict(stracks: List["STrack"]):
        """Run one batched Kalman prediction step for all given tracks using ``shared_kalman``."""
        if len(stracks) <= 0:
            return
        multi_mean = np.asarray([st.mean.copy() for st in stracks])
        multi_covariance = np.asarray([st.covariance for st in stracks])
        for i, st in enumerate(stracks):
            if st.state != TrackState.Tracked:
                multi_mean[i][7] = 0
        multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
        for st, mean, cov in zip(stracks, multi_mean, multi_covariance):
            st.mean = mean
            st.covariance = cov

    @staticmethod
    def multi_gmc(stracks: List["STrack"], H: np.ndarray = np.eye(2, 3)):
        """
        Apply a 2x3 affine camera-motion transform to the state of several tracks.

        Args:
            stracks (List[STrack]): Tracks to transform in place.
            H (np.ndarray): 2x3 matrix; ``H[:2, :2]`` is applied to every 2-vector of the 8-dim
                state and ``H[:2, 2]`` is added to the position. Only read, never modified.
        """
        if len(stracks) > 0:
            multi_mean = np.asarray([st.mean.copy() for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])

            R = H[:2, :2]
            # Block-diagonal 8x8 that applies R to each consecutive pair of state entries.
            R8x8 = np.kron(np.eye(4, dtype=float), R)
            t = H[:2, 2]

            for st, mean, cov in zip(stracks, multi_mean, multi_covariance):
                mean = R8x8.dot(mean)
                mean[:2] += t
                cov = R8x8.dot(cov).dot(R8x8.transpose())

                st.mean = mean
                st.covariance = cov

    def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int):
        """Start a new tracklet: assign a fresh id and initialise the filter state from the initial box."""
        self.kalman_filter = kalman_filter
        self.track_id = self.next_id()
        self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        # Only tracks born in frame 1 are confirmed immediately; others become activated
        # on their next update()/re_activate().
        if frame_id == 1:
            self.is_activated = True
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track: "STrack", frame_id: int, new_id: bool = False):
        """Reactivate a previously lost track from a new detection, optionally under a new track id."""
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.convert_coords(new_track.tlwh)
        )
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:
            self.track_id = self.next_id()
        self._adopt_detection(new_track)

    def update(self, new_track: "STrack", frame_id: int):
        """
        Update the state of a matched track.

        Args:
            new_track (STrack): The new track containing updated information.
            frame_id (int): The ID of the current frame.
        """
        self.frame_id = frame_id
        self.tracklet_len += 1

        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.convert_coords(new_track.tlwh)
        )
        self.state = TrackState.Tracked
        self.is_activated = True
        self._adopt_detection(new_track)

    def _adopt_detection(self, new_track: "STrack"):
        """Copy per-detection attributes (score, class, angle, index, original detection) from ``new_track``."""
        self.score = new_track.score
        self.cls = new_track.cls
        self.angle = new_track.angle
        self.idx = new_track.idx
        # Keep the raw detection payload (needed for face recognition fields) in sync with the match.
        if hasattr(new_track, 'original_detection'):
            self.original_detection = new_track.original_detection

    def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
        """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
        return self.tlwh_to_xyah(tlwh)

    @property
    def tlwh(self) -> np.ndarray:
        """Return a fresh copy of the box as (top-left x, top-left y, width, height)."""
        if self.mean is None:
            return self._tlwh.copy()
        # State starts with (center x, center y, aspect, height); width = aspect * height.
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]
        ret[:2] -= ret[2:] / 2
        return ret

    @property
    def xyxy(self) -> np.ndarray:
        """Return the box as (min x, min y, max x, max y)."""
        ret = self.tlwh  # already a fresh copy
        ret[2:] += ret[:2]
        return ret

    @staticmethod
    def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray:
        """Convert a box from tlwh to (center x, center y, aspect = w / h, height)."""
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    @property
    def xywh(self) -> np.ndarray:
        """Return the box as (center x, center y, width, height)."""
        ret = self.tlwh  # already a fresh copy
        ret[:2] += ret[2:] / 2
        return ret

    @property
    def xywha(self) -> np.ndarray:
        """Return (center x, center y, width, height, angle), or just xywh when no angle is set."""
        if self.angle is None:
            return self.xywh
        # atleast_1d accepts plain Python floats as well as NumPy scalars/arrays.
        return np.concatenate([self.xywh, np.atleast_1d(self.angle)])

    @property
    def result(self) -> List[float]:
        """Return box coordinates (xyxy, or xywha when an angle is set) followed by track_id, score, cls, idx."""
        coords = self.xyxy if self.angle is None else self.xywha
        return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]

    def __repr__(self) -> str:
        """Return a string with the track ID and its start/end frames (end_frame presumably comes from BaseTrack)."""
        return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"