matrice-analytics 0.1.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrice_analytics/__init__.py +28 -0
- matrice_analytics/boundary_drawing_internal/README.md +305 -0
- matrice_analytics/boundary_drawing_internal/__init__.py +45 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_internal.py +1207 -0
- matrice_analytics/boundary_drawing_internal/boundary_drawing_tool.py +429 -0
- matrice_analytics/boundary_drawing_internal/boundary_tool_template.html +1036 -0
- matrice_analytics/boundary_drawing_internal/data/.gitignore +12 -0
- matrice_analytics/boundary_drawing_internal/example_usage.py +206 -0
- matrice_analytics/boundary_drawing_internal/usage/README.md +110 -0
- matrice_analytics/boundary_drawing_internal/usage/boundary_drawer_launcher.py +102 -0
- matrice_analytics/boundary_drawing_internal/usage/simple_boundary_launcher.py +107 -0
- matrice_analytics/post_processing/README.md +455 -0
- matrice_analytics/post_processing/__init__.py +732 -0
- matrice_analytics/post_processing/advanced_tracker/README.md +650 -0
- matrice_analytics/post_processing/advanced_tracker/__init__.py +17 -0
- matrice_analytics/post_processing/advanced_tracker/base.py +99 -0
- matrice_analytics/post_processing/advanced_tracker/config.py +77 -0
- matrice_analytics/post_processing/advanced_tracker/kalman_filter.py +370 -0
- matrice_analytics/post_processing/advanced_tracker/matching.py +195 -0
- matrice_analytics/post_processing/advanced_tracker/strack.py +230 -0
- matrice_analytics/post_processing/advanced_tracker/tracker.py +367 -0
- matrice_analytics/post_processing/config.py +146 -0
- matrice_analytics/post_processing/core/__init__.py +63 -0
- matrice_analytics/post_processing/core/base.py +704 -0
- matrice_analytics/post_processing/core/config.py +3291 -0
- matrice_analytics/post_processing/core/config_utils.py +925 -0
- matrice_analytics/post_processing/face_reg/__init__.py +43 -0
- matrice_analytics/post_processing/face_reg/compare_similarity.py +556 -0
- matrice_analytics/post_processing/face_reg/embedding_manager.py +950 -0
- matrice_analytics/post_processing/face_reg/face_recognition.py +2234 -0
- matrice_analytics/post_processing/face_reg/face_recognition_client.py +606 -0
- matrice_analytics/post_processing/face_reg/people_activity_logging.py +321 -0
- matrice_analytics/post_processing/ocr/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/easyocr_extractor.py +250 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/__init__.py +9 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/__init__.py +4 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/cli.py +33 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/dataset_stats.py +139 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/export.py +398 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/train.py +447 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/utils.py +129 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/valid.py +93 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/validate_dataset.py +240 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_augmentation.py +176 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/cli/visualize_predictions.py +96 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/process.py +246 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/types.py +60 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/core/utils.py +87 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/__init__.py +3 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/config.py +82 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/hub.py +141 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/inference/plate_recognizer.py +323 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/py.typed +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/augmentation.py +101 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/data/dataset.py +97 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/config.py +114 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/layers.py +553 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/loss.py +55 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/metric.py +86 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_builders.py +95 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/model/model_schema.py +395 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/__init__.py +0 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/backend_utils.py +38 -0
- matrice_analytics/post_processing/ocr/fast_plate_ocr_py38/train/utilities/utils.py +214 -0
- matrice_analytics/post_processing/ocr/postprocessing.py +270 -0
- matrice_analytics/post_processing/ocr/preprocessing.py +52 -0
- matrice_analytics/post_processing/post_processor.py +1175 -0
- matrice_analytics/post_processing/test_cases/__init__.py +1 -0
- matrice_analytics/post_processing/test_cases/run_tests.py +143 -0
- matrice_analytics/post_processing/test_cases/test_advanced_customer_service.py +841 -0
- matrice_analytics/post_processing/test_cases/test_basic_counting_tracking.py +523 -0
- matrice_analytics/post_processing/test_cases/test_comprehensive.py +531 -0
- matrice_analytics/post_processing/test_cases/test_config.py +852 -0
- matrice_analytics/post_processing/test_cases/test_customer_service.py +585 -0
- matrice_analytics/post_processing/test_cases/test_data_generators.py +583 -0
- matrice_analytics/post_processing/test_cases/test_people_counting.py +510 -0
- matrice_analytics/post_processing/test_cases/test_processor.py +524 -0
- matrice_analytics/post_processing/test_cases/test_usecases.py +165 -0
- matrice_analytics/post_processing/test_cases/test_utilities.py +356 -0
- matrice_analytics/post_processing/test_cases/test_utils.py +743 -0
- matrice_analytics/post_processing/usecases/Histopathological_Cancer_Detection_img.py +604 -0
- matrice_analytics/post_processing/usecases/__init__.py +267 -0
- matrice_analytics/post_processing/usecases/abandoned_object_detection.py +797 -0
- matrice_analytics/post_processing/usecases/advanced_customer_service.py +1601 -0
- matrice_analytics/post_processing/usecases/age_detection.py +842 -0
- matrice_analytics/post_processing/usecases/age_gender_detection.py +1085 -0
- matrice_analytics/post_processing/usecases/anti_spoofing_detection.py +656 -0
- matrice_analytics/post_processing/usecases/assembly_line_detection.py +841 -0
- matrice_analytics/post_processing/usecases/banana_defect_detection.py +624 -0
- matrice_analytics/post_processing/usecases/basic_counting_tracking.py +667 -0
- matrice_analytics/post_processing/usecases/blood_cancer_detection_img.py +881 -0
- matrice_analytics/post_processing/usecases/car_damage_detection.py +834 -0
- matrice_analytics/post_processing/usecases/car_part_segmentation.py +946 -0
- matrice_analytics/post_processing/usecases/car_service.py +1601 -0
- matrice_analytics/post_processing/usecases/cardiomegaly_classification.py +864 -0
- matrice_analytics/post_processing/usecases/cell_microscopy_segmentation.py +897 -0
- matrice_analytics/post_processing/usecases/chicken_pose_detection.py +648 -0
- matrice_analytics/post_processing/usecases/child_monitoring.py +814 -0
- matrice_analytics/post_processing/usecases/color/clip.py +660 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/merges.txt +48895 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/preprocessor_config.json +28 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/special_tokens_map.json +30 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer.json +245079 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/tokenizer_config.json +32 -0
- matrice_analytics/post_processing/usecases/color/clip_processor/vocab.json +1 -0
- matrice_analytics/post_processing/usecases/color/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/color/color_mapper.py +468 -0
- matrice_analytics/post_processing/usecases/color_detection.py +1936 -0
- matrice_analytics/post_processing/usecases/color_map_utils.py +70 -0
- matrice_analytics/post_processing/usecases/concrete_crack_detection.py +827 -0
- matrice_analytics/post_processing/usecases/crop_weed_detection.py +781 -0
- matrice_analytics/post_processing/usecases/customer_service.py +1008 -0
- matrice_analytics/post_processing/usecases/defect_detection_products.py +936 -0
- matrice_analytics/post_processing/usecases/distracted_driver_detection.py +822 -0
- matrice_analytics/post_processing/usecases/drone_traffic_monitoring.py +585 -0
- matrice_analytics/post_processing/usecases/drowsy_driver_detection.py +829 -0
- matrice_analytics/post_processing/usecases/dwell_detection.py +829 -0
- matrice_analytics/post_processing/usecases/emergency_vehicle_detection.py +827 -0
- matrice_analytics/post_processing/usecases/face_emotion.py +813 -0
- matrice_analytics/post_processing/usecases/face_recognition.py +827 -0
- matrice_analytics/post_processing/usecases/fashion_detection.py +835 -0
- matrice_analytics/post_processing/usecases/field_mapping.py +902 -0
- matrice_analytics/post_processing/usecases/fire_detection.py +1146 -0
- matrice_analytics/post_processing/usecases/flare_analysis.py +836 -0
- matrice_analytics/post_processing/usecases/flower_segmentation.py +1006 -0
- matrice_analytics/post_processing/usecases/gas_leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/gender_detection.py +832 -0
- matrice_analytics/post_processing/usecases/human_activity_recognition.py +871 -0
- matrice_analytics/post_processing/usecases/intrusion_detection.py +1672 -0
- matrice_analytics/post_processing/usecases/leaf.py +821 -0
- matrice_analytics/post_processing/usecases/leaf_disease.py +840 -0
- matrice_analytics/post_processing/usecases/leak_detection.py +837 -0
- matrice_analytics/post_processing/usecases/license_plate_detection.py +1188 -0
- matrice_analytics/post_processing/usecases/license_plate_monitoring.py +1781 -0
- matrice_analytics/post_processing/usecases/litter_monitoring.py +717 -0
- matrice_analytics/post_processing/usecases/mask_detection.py +869 -0
- matrice_analytics/post_processing/usecases/natural_disaster.py +907 -0
- matrice_analytics/post_processing/usecases/parking.py +787 -0
- matrice_analytics/post_processing/usecases/parking_space_detection.py +822 -0
- matrice_analytics/post_processing/usecases/pcb_defect_detection.py +888 -0
- matrice_analytics/post_processing/usecases/pedestrian_detection.py +808 -0
- matrice_analytics/post_processing/usecases/people_counting.py +706 -0
- matrice_analytics/post_processing/usecases/people_counting_bckp.py +1683 -0
- matrice_analytics/post_processing/usecases/people_tracking.py +1842 -0
- matrice_analytics/post_processing/usecases/pipeline_detection.py +605 -0
- matrice_analytics/post_processing/usecases/plaque_segmentation_img.py +874 -0
- matrice_analytics/post_processing/usecases/pothole_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/ppe_compliance.py +645 -0
- matrice_analytics/post_processing/usecases/price_tag_detection.py +822 -0
- matrice_analytics/post_processing/usecases/proximity_detection.py +1901 -0
- matrice_analytics/post_processing/usecases/road_lane_detection.py +623 -0
- matrice_analytics/post_processing/usecases/road_traffic_density.py +832 -0
- matrice_analytics/post_processing/usecases/road_view_segmentation.py +915 -0
- matrice_analytics/post_processing/usecases/shelf_inventory_detection.py +583 -0
- matrice_analytics/post_processing/usecases/shoplifting_detection.py +822 -0
- matrice_analytics/post_processing/usecases/shopping_cart_analysis.py +899 -0
- matrice_analytics/post_processing/usecases/skin_cancer_classification_img.py +864 -0
- matrice_analytics/post_processing/usecases/smoker_detection.py +833 -0
- matrice_analytics/post_processing/usecases/solar_panel.py +810 -0
- matrice_analytics/post_processing/usecases/suspicious_activity_detection.py +1030 -0
- matrice_analytics/post_processing/usecases/template_usecase.py +380 -0
- matrice_analytics/post_processing/usecases/theft_detection.py +648 -0
- matrice_analytics/post_processing/usecases/traffic_sign_monitoring.py +724 -0
- matrice_analytics/post_processing/usecases/underground_pipeline_defect_detection.py +775 -0
- matrice_analytics/post_processing/usecases/underwater_pollution_detection.py +842 -0
- matrice_analytics/post_processing/usecases/vehicle_monitoring.py +1029 -0
- matrice_analytics/post_processing/usecases/warehouse_object_segmentation.py +899 -0
- matrice_analytics/post_processing/usecases/waterbody_segmentation.py +923 -0
- matrice_analytics/post_processing/usecases/weapon_detection.py +771 -0
- matrice_analytics/post_processing/usecases/weld_defect_detection.py +615 -0
- matrice_analytics/post_processing/usecases/wildlife_monitoring.py +898 -0
- matrice_analytics/post_processing/usecases/windmill_maintenance.py +834 -0
- matrice_analytics/post_processing/usecases/wound_segmentation.py +856 -0
- matrice_analytics/post_processing/utils/__init__.py +150 -0
- matrice_analytics/post_processing/utils/advanced_counting_utils.py +400 -0
- matrice_analytics/post_processing/utils/advanced_helper_utils.py +317 -0
- matrice_analytics/post_processing/utils/advanced_tracking_utils.py +461 -0
- matrice_analytics/post_processing/utils/alerting_utils.py +213 -0
- matrice_analytics/post_processing/utils/category_mapping_utils.py +94 -0
- matrice_analytics/post_processing/utils/color_utils.py +592 -0
- matrice_analytics/post_processing/utils/counting_utils.py +182 -0
- matrice_analytics/post_processing/utils/filter_utils.py +261 -0
- matrice_analytics/post_processing/utils/format_utils.py +293 -0
- matrice_analytics/post_processing/utils/geometry_utils.py +300 -0
- matrice_analytics/post_processing/utils/smoothing_utils.py +358 -0
- matrice_analytics/post_processing/utils/tracking_utils.py +234 -0
- matrice_analytics/py.typed +0 -0
- matrice_analytics-0.1.60.dist-info/METADATA +481 -0
- matrice_analytics-0.1.60.dist-info/RECORD +196 -0
- matrice_analytics-0.1.60.dist-info/WHEEL +5 -0
- matrice_analytics-0.1.60.dist-info/licenses/LICENSE.txt +21 -0
- matrice_analytics-0.1.60.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Matching utilities for track association.
|
|
3
|
+
|
|
4
|
+
This module provides utilities for matching tracks with detections,
|
|
5
|
+
including IoU distance calculation and linear assignment.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import scipy
|
|
10
|
+
from scipy.spatial.distance import cdist
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
import lap # for linear_assignment
|
|
14
|
+
assert lap.__version__ # verify package is not directory
|
|
15
|
+
except (ImportError, AssertionError, AttributeError):
|
|
16
|
+
# Fallback to scipy if lap is not available
|
|
17
|
+
lap = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True):
    """
    Solve the assignment problem for a cost matrix and split the result into matches and leftovers.

    ``lap.lapjv`` is used when it was importable and ``use_lap`` is True; otherwise
    ``scipy.optimize.linear_sum_assignment`` is used and pairs whose cost exceeds ``thresh`` are dropped.

    Args:
        cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
        thresh (float): Threshold for considering an assignment valid.
        use_lap (bool): Use lap.lapjv for the assignment. If False (or lap is unavailable),
            scipy.optimize.linear_sum_assignment is used.

    Returns:
        matches (np.ndarray): Integer array of shape (K, 2) holding (row, column) pairs; K may be 0.
        unmatched_a: Row indices that received no match (tuple for an empty matrix, np.ndarray with lap,
            sorted list with scipy).
        unmatched_b: Column indices that received no match (same container types as ``unmatched_a``).
    """
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))

    if use_lap and lap is not None:
        # lap.lapjv marks unassigned rows/columns with a negative index.
        _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
        pairs = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
        # reshape keeps the (K, 2) contract even when no pair was found.
        matches = np.asarray(pairs, dtype=int).reshape(-1, 2)
        unmatched_a = np.where(x < 0)[0]
        unmatched_b = np.where(y < 0)[0]
    else:
        # Import the function explicitly: a bare `import scipy` does not guarantee that the
        # `scipy.optimize` submodule is loaded on every SciPy version.
        from scipy.optimize import linear_sum_assignment

        rows, cols = linear_sum_assignment(cost_matrix)  # row indices, column indices
        accepted = cost_matrix[rows, cols] <= thresh
        matches = np.stack([rows[accepted], cols[accepted]], axis=1).astype(int)
        matched_rows = set(matches[:, 0].tolist())
        matched_cols = set(matches[:, 1].tolist())
        unmatched_a = [i for i in range(cost_matrix.shape[0]) if i not in matched_rows]
        unmatched_b = [j for j in range(cost_matrix.shape[1]) if j not in matched_cols]

    return matches, unmatched_a, unmatched_b
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = True) -> np.ndarray:
    """
    Calculate the pairwise IoU (default) or intersection-over-box1-area of two box sets.

    Boxes are in x1y1x2y2 format. Pairs whose denominator is not positive (e.g. zero-area
    boxes) yield 0 instead of NaN/inf so the result can be fed to an assignment solver.

    Args:
        box1 (np.ndarray): First set of boxes (N, 4)
        box2 (np.ndarray): Second set of boxes (M, 4)
        iou (bool): If True, calculate IoU, otherwise calculate IoA (intersection / area of box1)

    Returns:
        np.ndarray: IoU/IoA matrix of shape (N, M)
    """
    # Broadcast (N, 1, 2) against (M, 2) to get every pairwise intersection rectangle.
    top_left = np.maximum(box1[:, None, :2], box2[:, :2])  # (N, M, 2)
    bottom_right = np.minimum(box1[:, None, 2:], box2[:, 2:])  # (N, M, 2)
    wh = np.maximum(0, bottom_right - top_left)  # clamp: disjoint boxes have no overlap
    inter = wh[:, :, 0] * wh[:, :, 1]  # (N, M)

    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])  # (N,)
    if iou:
        area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])  # (M,)
        denom = area1[:, None] + area2 - inter  # union, (N, M)
    else:
        denom = area1[:, None]  # (N, 1), broadcast over columns

    # Divide only where the denominator is positive; everything else stays 0.
    out = np.zeros(inter.shape, dtype=np.result_type(inter.dtype, np.float32))
    np.divide(inter, denom, out=out, where=denom > 0)
    return out
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def iou_distance(atracks: list, btracks: list) -> np.ndarray:
    """
    Compute cost based on Intersection over Union (IoU) between tracks.

    If the first element of either input is a NumPy array, both inputs are treated as raw
    x1y1x2y2 boxes and used as-is. Otherwise a box is read from each object's ``xyxy``
    attribute, or derived from ``tlwh``; objects exposing neither contribute ``[0, 0, 0, 0]``.

    Args:
        atracks (List[STrack] or List[np.ndarray]): List of tracks 'a' or bounding boxes.
        btracks (List[STrack] or List[np.ndarray]): List of tracks 'b' or bounding boxes.

    Returns:
        (np.ndarray): Cost matrix (1 - IoU) with shape (len(atracks), len(btracks)).
    """
    if _holds_raw_boxes(atracks) or _holds_raw_boxes(btracks):
        atlbrs, btlbrs = atracks, btracks
    else:
        atlbrs = [_track_to_xyxy(track) for track in atracks]
        btlbrs = [_track_to_xyxy(track) for track in btracks]

    # When either side is empty the zero matrix is kept and bbox_ioa is never called.
    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
    if len(atlbrs) and len(btlbrs):
        ious = bbox_ioa(
            np.ascontiguousarray(atlbrs, dtype=np.float32),
            np.ascontiguousarray(btlbrs, dtype=np.float32),
            iou=True,
        )
    return 1 - ious  # cost matrix


def _holds_raw_boxes(items) -> bool:
    """Return True when ``items`` is non-empty and its first element is a NumPy array.

    Uses ``len`` rather than truthiness so a 2-D array of boxes does not raise
    "truth value of an array is ambiguous".
    """
    return len(items) > 0 and isinstance(items[0], np.ndarray)


def _track_to_xyxy(track):
    """Return an x1y1x2y2 box for ``track`` from ``xyxy``, else ``tlwh``, else all zeros."""
    if hasattr(track, 'xyxy'):
        return track.xyxy
    if hasattr(track, 'tlwh'):
        tlwh = track.tlwh
        return [tlwh[0], tlwh[1], tlwh[0] + tlwh[2], tlwh[1] + tlwh[3]]
    return [0, 0, 0, 0]
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
    """
    Compute distance between tracks and detections based on embeddings.

    Track embeddings are read from ``smooth_feat`` and detection embeddings from ``curr_feat``.
    Objects whose embedding is missing or None get a zero vector. The zero vector's length is
    taken from the first available embedding (128 only when none exists), so placeholders
    always match the real feature width.

    Args:
        tracks (List[STrack] or List[np.ndarray]): List of tracks, where each track contains embedding features.
        detections (List[BaseTrack]): List of detections, where each detection contains embedding features.
        metric (str): Metric passed to ``scipy.spatial.distance.cdist`` ('cosine', 'euclidean', etc.).

    Returns:
        (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks
            and M is the number of detections.
    """
    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
    if cost_matrix.size == 0:
        return cost_matrix

    det_feats = [getattr(det, 'curr_feat', None) for det in detections]
    track_feats = [getattr(track, 'smooth_feat', None) for track in tracks]

    # Size the placeholder from real data; 128 is only the last-resort default.
    feat_dim = next((np.size(feat) for feat in det_feats + track_feats if feat is not None), 128)
    placeholder = np.zeros(feat_dim)

    det_features = np.asarray(
        [placeholder if feat is None else feat for feat in det_feats], dtype=np.float32
    )
    track_features = np.asarray(
        [placeholder if feat is None else feat for feat in track_feats], dtype=np.float32
    )

    # Clamp tiny negative values caused by floating-point round-off.
    return np.maximum(0.0, cdist(track_features, det_features, metric))
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
    """
    Weight an IoU-style cost matrix by detection confidence and return the fused cost.

    The cost is turned into a similarity (1 - cost), each column is multiplied by the
    ``score`` of the corresponding detection, and the product is turned back into a cost.

    Args:
        cost_matrix (np.ndarray): Assignment costs with shape (N, M).
        detections (List[BaseTrack]): Detections, each exposing a ``score`` attribute.

    Returns:
        (np.ndarray): Fused cost matrix with shape (N, M); an empty input is returned unchanged.
    """
    if cost_matrix.size == 0:
        return cost_matrix

    scores = np.array([detection.score for detection in detections])
    # A (1, M) row of scores broadcasts across all N rows of the similarity matrix.
    similarity = (1 - cost_matrix) * scores[np.newaxis, :]
    return 1 - similarity  # fuse_cost
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
"""
|
|
2
|
+
STrack class for single object tracking.
|
|
3
|
+
|
|
4
|
+
This module provides the STrack class that represents a single tracked object
|
|
5
|
+
with Kalman filtering for state estimation.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any, List
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from .base import BaseTrack, TrackState
|
|
12
|
+
from .kalman_filter import KalmanFilterXYAH
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def xywh2ltwh(xywh: List[float]) -> List[float]:
    """
    Convert a center-based box (cx, cy, w, h) to a top-left-based box (left, top, w, h).

    Args:
        xywh (List[float]): Bounding box as [center x, center y, width, height]

    Returns:
        List[float]: Bounding box as [left, top, width, height]
    """
    center_x, center_y, width, height = xywh
    left = center_x - width / 2
    top = center_y - height / 2
    return [left, top, width, height]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class STrack(BaseTrack):
    """
    Single object tracking representation that uses Kalman filtering for state estimation.

    This class is responsible for storing all the information regarding individual tracklets and performs state updates
    and predictions based on Kalman filter.

    Attributes:
        shared_kalman (KalmanFilterXYAH): Shared Kalman filter used across all STrack instances for prediction.
        _tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.
        kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.
        mean (np.ndarray): Mean state estimate vector.
        covariance (np.ndarray): Covariance of state estimate.
        is_activated (bool): Boolean flag indicating if the track has been activated.
        score (float): Confidence score of the track.
        tracklet_len (int): Length of the tracklet.
        cls (Any): Class label for the object.
        idx (int): Index or identifier for the object (last element of the constructor's ``xywh``).
        frame_id (int): Current frame ID.
        start_frame (int): Frame where the object was first detected.
        angle (float or None): Optional angle information for oriented bounding boxes.
    """

    shared_kalman = KalmanFilterXYAH()

    def __init__(self, xywh: List[float], score: float, cls: Any):
        """
        Initialize a new STrack instance.

        Args:
            xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where
                (x, y) is the center, (w, h) are width and height, [a] is the optional angle, and idx is the id.
            score (float): Confidence score of the detection.
            cls (Any): Class label for the detected object.
        """
        super().__init__()
        # xywh+idx or xywha+idx
        assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
        self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
        # Filter state stays unset until activate() is called.
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        self.score = score
        self.tracklet_len = 0
        self.cls = cls
        self.idx = xywh[-1]
        self.angle = xywh[4] if len(xywh) == 6 else None

    def predict(self):
        """Predict the next state (mean and covariance) of the object using the Kalman filter."""
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            # NOTE(review): index 7 is presumably the height-velocity term of the XYAH state;
            # confirm against KalmanFilterXYAH.
            mean_state[7] = 0
        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

    @staticmethod
    def multi_predict(stracks: List["STrack"]):
        """Run one batched Kalman prediction step on ``shared_kalman`` and write the results back to each track."""
        if len(stracks) <= 0:
            return
        multi_mean = np.asarray([st.mean.copy() for st in stracks])
        multi_covariance = np.asarray([st.covariance for st in stracks])
        for i, st in enumerate(stracks):
            if st.state != TrackState.Tracked:
                # Same zeroing of state component 7 as in predict().
                multi_mean[i][7] = 0
        multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
        for st, mean, cov in zip(stracks, multi_mean, multi_covariance):
            st.mean = mean
            st.covariance = cov

    @staticmethod
    def multi_gmc(stracks: List["STrack"], H: np.ndarray = np.eye(2, 3)):
        """
        Apply a 2x3 motion-compensation matrix to the state of every given track.

        ``H[:2, :2]`` is applied to each consecutive pair of state components (and to the covariance),
        and ``H[:2, 2]`` is added to the first two state components. The default ``H`` is only read,
        never modified, so sharing it between calls is safe.

        Args:
            stracks (List[STrack]): Tracks whose ``mean`` and ``covariance`` are updated in place.
            H (np.ndarray): 2x3 transform matrix; the default leaves the state unchanged.
        """
        if len(stracks) > 0:
            multi_mean = np.asarray([st.mean.copy() for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])

            R = H[:2, :2]
            # Block-diagonal 8x8 matrix with R repeated four times.
            R8x8 = np.kron(np.eye(4, dtype=float), R)
            t = H[:2, 2]

            for st, mean, cov in zip(stracks, multi_mean, multi_covariance):
                mean = R8x8.dot(mean)
                mean[:2] += t
                cov = R8x8.dot(cov).dot(R8x8.transpose())

                st.mean = mean
                st.covariance = cov

    def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int):
        """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
        self.kalman_filter = kalman_filter
        self.track_id = self.next_id()
        self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        # Only tracks created on frame 1 are flagged as activated here; others are
        # flagged later by update() or re_activate().
        if frame_id == 1:
            self.is_activated = True
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track: "STrack", frame_id: int, new_id: bool = False):
        """
        Reactivate a previously lost track using new detection data and update its state and attributes.

        Args:
            new_track (STrack): Detection-backed track supplying the new measurement and attributes.
            frame_id (int): The ID of the current frame.
            new_id (bool): If True, assign a fresh track id via ``next_id()``.
        """
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.convert_coords(new_track.tlwh)
        )
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:
            self.track_id = self.next_id()
        self._adopt_detection(new_track)

    def update(self, new_track: "STrack", frame_id: int):
        """
        Update the state of a matched track.

        Args:
            new_track (STrack): The new track containing updated information.
            frame_id (int): The ID of the current frame.
        """
        self.frame_id = frame_id
        self.tracklet_len += 1

        new_tlwh = new_track.tlwh
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.convert_coords(new_tlwh)
        )
        self.state = TrackState.Tracked
        self.is_activated = True

        self._adopt_detection(new_track)

    def _adopt_detection(self, new_track: "STrack") -> None:
        """Copy score, class, angle, idx and (if present) ``original_detection`` from ``new_track``."""
        self.score = new_track.score
        self.cls = new_track.cls
        self.angle = new_track.angle
        self.idx = new_track.idx

        # CRITICAL FIX: Preserve original detection data for face recognition fields
        if hasattr(new_track, 'original_detection'):
            self.original_detection = new_track.original_detection

    def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
        """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
        return self.tlwh_to_xyah(tlwh)

    @property
    def tlwh(self) -> np.ndarray:
        """Get the bounding box in top-left-width-height format from the current state estimate."""
        if self.mean is None:
            # No filter state yet: fall back to the box given at construction.
            return self._tlwh.copy()
        ret = self.mean[:4].copy()  # (center x, center y, aspect, height)
        ret[2] *= ret[3]  # aspect * height -> width
        ret[:2] -= ret[2:] / 2  # center -> top-left
        return ret

    @property
    def xyxy(self) -> np.ndarray:
        """Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    @staticmethod
    def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray:
        """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
        ret = np.asarray(tlwh).copy()
        if not np.issubdtype(ret.dtype, np.floating):
            # The in-place float arithmetic below cannot be stored in an integer array.
            ret = ret.astype(float)
        ret[:2] += ret[2:] / 2  # top-left -> center
        ret[2] /= ret[3]  # width -> aspect ratio (width / height)
        return ret

    @property
    def xywh(self) -> np.ndarray:
        """Get the current position of the bounding box in (center x, center y, width, height) format."""
        ret = np.asarray(self.tlwh).copy()
        ret[:2] += ret[2:] / 2
        return ret

    @property
    def xywha(self) -> np.ndarray:
        """Get position in (center x, center y, width, height, angle) format; falls back to xywh if angle is None."""
        if self.angle is None:
            return self.xywh
        # atleast_1d accepts both NumPy scalars and plain Python floats
        # (a Python float does not support ``angle[None]``).
        return np.concatenate([self.xywh, np.atleast_1d(self.angle)])

    @property
    def result(self) -> List[float]:
        """Get the current tracking results: xyxy (or xywha when an angle is set) plus track id, score, cls, idx."""
        coords = self.xyxy if self.angle is None else self.xywha
        return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]

    def __repr__(self) -> str:
        """Return a string representation of the STrack object including start frame, end frame, and track ID."""
        return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
|