dnt 0.3.1.3__py3-none-any.whl → 0.3.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dnt might be problematic.
- dnt/__init__.py +1 -1
- dnt/analysis/__init__.py +3 -3
- dnt/analysis/count.py +54 -37
- dnt/analysis/{interaction.py → interaction2.py} +23 -8
- dnt/analysis/stop3.py +7 -3
- dnt/detect/signal/detector.py +13 -4
- dnt/detect/timestamp.py +105 -0
- dnt/detect/yolov8/detector.py +72 -29
- dnt/detect/yolov8/segmentor.py +60 -2
- dnt/filter/filter.py +19 -8
- dnt/label/labeler2.py +170 -67
- dnt/shared/synhcro.py +1 -1
- dnt/track/dsort/deep_sort/deep_sort.py +4 -3
- dnt/track/dsort/deep_sort/sort/detection.py +2 -1
- dnt/track/dsort/deep_sort/sort/track.py +2 -1
- dnt/track/dsort/deep_sort/sort/tracker.py +1 -1
- dnt/track/dsort/dsort.py +34 -17
- dnt/track/re_class.py +29 -6
- dnt/track/sort/sort.py +4 -5
- dnt/track/tracker.py +9 -5
- {dnt-0.3.1.3.dist-info → dnt-0.3.1.7.dist-info}/METADATA +16 -8
- {dnt-0.3.1.3.dist-info → dnt-0.3.1.7.dist-info}/RECORD +25 -24
- {dnt-0.3.1.3.dist-info → dnt-0.3.1.7.dist-info}/WHEEL +1 -1
- {dnt-0.3.1.3.dist-info → dnt-0.3.1.7.dist-info/licenses}/LICENSE +0 -0
- {dnt-0.3.1.3.dist-info → dnt-0.3.1.7.dist-info}/top_level.txt +0 -0
dnt/__init__.py
CHANGED
dnt/analysis/__init__.py
CHANGED
dnt/analysis/count.py
CHANGED
@@ -1,54 +1,73 @@
-
-
-
+# pylint: disable=too-many-arguments
+'''
+This script is used to count the number of tracks that passing a reference line for post-analysis
+last modified: 2021-09-30
+'''
+from shapely.geometry import Polygon, LineString
+import geopandas as gpd
+import pandas as pd
 from tqdm import tqdm
 import numpy as np
 
 class Count:
-
-
+    '''
+    Count tracks that passing a reference line for post-analysis
+    Methods:
+        __init__(): initialize the class
+        count_tracks_by_line():
+            count tracks that passing a reference line for post-analysis
+    '''
+    def __init__(self) -> None:
+        self.track_fields = ['frame', 'track_id', 'x', 'y', 'w', 'h', 'r1', 'r2', 'r3', 'r4']
 
-    def
-
-
+    def count_tracks_by_line(self,
+                             tracks:pd.DataFrame=None,
+                             track_file:str=None,
+                             line: LineString=None,
+                             video_index:int=None,
+                             video_tot:int=None,
+                             verbose:bool=True) -> pd.DataFrame:
         '''
         Count tracks that passing a reference line for post-analysis
         Inputs:
-            tracks - a DataFrame of tracks,
-                if None, read track_file
+            tracks - a DataFrame of tracks, if None (default), read track_file
             track_file - a txt file contains tracks,
-            line - the reference line
+            line - the reference line, LineString(pointA, pointB)
             video_index - the index of video for processing
             video_tot - the total number of videos
+            verbose - if True, show progress bar
         Return:
            A DataFrame of track_id, frame, direction, count
        '''
-
+        # Load tracks
        if tracks is None:
            tracks = pd.read_csv(track_file, header=None)
-
        if len(tracks.columns) != len(self.track_fields):
            raise Exception("The tracks format is incorrect!")
-
        tracks.columns = self.track_fields
        track_ids = tracks['track_id'].unique()
 
-
-
-
-
-
-
-
-
-
-            pbar.set_description_str("Counting {} of {}".format(video_index, video_tot))
-        else:
-            pbar.set_description_str("Counting")
+        # Create a GeoDataFrame of tracks
+        geo = tracks.apply(lambda track: Polygon([(track['x'], track['y']),
+                                                  (track['x'] + track['w'], track['y']),
+                                                  (track['x'] + track['w'], track['y'] + track['h']),
+                                                  (track['x'], track['y'] + track['h'])]),
+                           axis =1)
+        geo_tracks = gpd.GeoDataFrame(tracks, geometry=geo)
+        point_a = np.array(line.coords[0])
+        point_b = np.array(line.coords[1])
 
+        # Interate through all tracks
+        if verbose:
+            pbar = tqdm(total=len(track_ids), unit=' tracks')
+            if video_index and video_tot:
+                pbar.set_description_str("Counting {} of {}".format(video_index, video_tot))
+            else:
+                pbar.set_description_str("Counting")
        intersected_tracks = []
        intersected_frames = []
-        intersected_direct = []
+        intersected_direct = []
+
        for track_id in track_ids:
            selected = geo_tracks.loc[(geo_tracks['track_id']==track_id)].copy()
            if len(selected)>0:

@@ -60,19 +79,19 @@ class Count:
                intersected_tracks.append(track_id)
                frame_pos = int(len(intersected)/2)
                intersected_frames.append(intersected['frame'].values[frame_pos])
-
+
                # center point of the first frame
-
+                point_c = np.array((intersected.iloc[0]['x'] + intersected.iloc[0]['w']/2,
                                    intersected.iloc[0]['y'] + intersected.iloc[0]['h']/2))
                d = 2 # right
-                if np.cross(
+                if np.cross(point_c-point_a, point_b-point_a) < 0:
                    d = 1 # left
-
                intersected_direct.append(d)
-
-
-
-
+            if verbose:
+                pbar.update()
+
+        if verbose:
+            pbar.close()
 
        results = pd.DataFrame(
            {

@@ -86,5 +105,3 @@ class Count:
        results['count'] = results.index + 1
 
        return results
-
-
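For orientation, the reworked `count_tracks_by_line` can be driven as below. This is a minimal sketch, not part of the diff: the track file name is a placeholder, and the import path assumes the package layout listed in the RECORD.

```python
from shapely.geometry import LineString
from dnt.analysis.count import Count

# Hypothetical reference line across a roadway, in pixel coordinates (point A -> point B).
ref_line = LineString([(100, 540), (1820, 540)])

counter = Count()
# 'tracks.txt' is a placeholder; it must hold the ten columns named in Count.track_fields.
results = counter.count_tracks_by_line(track_file='tracks.txt', line=ref_line, verbose=True)

# Per the docstring, the result carries track_id, frame, direction and a running count;
# direction comes from the sign of np.cross(point_c - point_a, point_b - point_a).
print(results.head())
```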
dnt/analysis/{interaction.py → interaction2.py}
RENAMED
@@ -1,21 +1,36 @@
+"""
+This module is used to analyze the interaction between two tracks.
+Author: Zhenyu Wang
+Date: 2023/10/20
+Email: wonstran@hotmail.com
+Updates:
+    2025/03/19: Add the function of nearmiss based on PET
+    2023/10/20: Initial version
+"""
 from shapely.geometry import Point, Polygon, LineString, box
 from shapely import intersection, distance, intersects
 from shapelysmooth import taubin_smooth, chaikin_smooth, catmull_rom_smooth
 import geopandas as gpd, pandas as pd
 from tqdm import tqdm
 import numpy as np
-from matplotlib import pyplot as plt
 from dnt.label.labeler2 import Labeler
 import os
-from ..track import Tracker
 
 class YieldAnalyzer:
-    def __init__(self,
-
-
-
-
-
+    def __init__(self,
+                 waiting_dist_p:int=600,
+                 waiting_dist_y:int=300,
+                 leading_p:bool=False,
+                 leading_axis_p:str='y',
+                 leading_y:bool=True,
+                 leading_axis_y:str='x',
+                 yield_gap:int=10,
+                 fps:int=30,
+                 ref_point='bc',
+                 ref_offset:tuple=(0,0),
+                 filter_buffer:int=0,
+                 p_zone:Polygon=None,
+                 y_zone:Polygon=None) -> None:
        '''
        Parameters:
            threshold: the hyperparameter to determine if a yield event (frame difference <=yield_gap*fps), default is 3 seconds
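The renamed module's `YieldAnalyzer` constructor (the hunk above is truncated before the rest of its docstring) could be instantiated as follows. A minimal sketch using only the parameters visible in the diff, with hypothetical zone polygons:

```python
from shapely.geometry import Polygon
from dnt.analysis.interaction2 import YieldAnalyzer

# Hypothetical pedestrian and yielding-vehicle zones in pixel coordinates.
p_zone = Polygon([(0, 0), (200, 0), (200, 200), (0, 200)])
y_zone = Polygon([(300, 0), (600, 0), (600, 300), (300, 300)])

# Defaults follow the signature above; yield_gap*fps is the frame-difference
# threshold mentioned in the truncated docstring.
analyzer = YieldAnalyzer(yield_gap=10, fps=30, p_zone=p_zone, y_zone=y_zone)
```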
dnt/analysis/stop3.py
CHANGED
@@ -6,11 +6,15 @@ from tqdm import tqdm
 from cython_bbox import bbox_overlaps
 import numpy as np
 #import mapply
-
+import sys
+import os
+sys.path.append(os.path.dirname(__file__))
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+from filter import Filter
 from matplotlib import pyplot as plt
 import random, os
-from
-from
+from label.labeler2 import Labeler
+from engine.bbox_iou import ious
 
 class StopAnalyzer():
    def __init__(self,
dnt/detect/signal/detector.py
CHANGED
@@ -12,19 +12,27 @@ import torchvision.transforms as transforms
 from torchvision.models.resnet import ResNet18_Weights
 from shared.download import download_file
 from matplotlib import pyplot as plt
+from ultralytics import YOLO
 
 class Model(nn.Module):
-    def __init__(self):
+    def __init__(self, num_class=2):
        super(Model, self).__init__()
 
        self.resnet18 = models.resnet18(weights=ResNet18_Weights.DEFAULT)
-        self.resnet18.fc = nn.Linear(512,
+        self.resnet18.fc = nn.Linear(512, num_class)
 
    def forward(self, x):
        return self.resnet18(x)
 
 class SignalDetector:
-    def __init__(self,
+    def __init__(self,
+                 det_zones:list,
+                 model:str='ped',
+                 weights:str=None,
+                 batchsz:int=64,
+                 num_class:int=2,
+                 threshold:float=0.98,
+                 device='auto'):
        '''
        Detect traffic signal status
 

@@ -33,6 +41,7 @@ class SignalDetector:
        - model: detection model, default is 'ped', 'custom'
        - weights: path of weights, default is None
        - batchsz: the batch size for prediction, default is 64
+        - num_class: the number of classes, default is 2
        - threshold: the threshold for detection, default is 0.98
        '''
 

@@ -47,7 +56,7 @@ class SignalDetector:
            url = 'https://its.cutr.usf.edu/alms/downloads/ped_signal.pt'
            download_file(url, weights)
 
-        self.model = Model()
+        self.model = Model(num_class)
        if device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        elif device == 'cuda' and torch.cuda.is_available():
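The effect of the new `num_class` parameter can be checked in isolation. A minimal sketch, assuming the module's own imports (it pulls `shared.download`) resolve in your environment; not part of the diff:

```python
import torch
from dnt.detect.signal.detector import Model

# The classifier head is now 512 -> num_class instead of a fixed two-class layer.
model = Model(num_class=3)
dummy = torch.randn(1, 3, 224, 224)   # one RGB image tensor
print(model(dummy).shape)             # torch.Size([1, 3])
```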
dnt/detect/timestamp.py
ADDED
@@ -0,0 +1,105 @@
+import easyocr
+import cv2
+import numpy as np
+import dateparser
+from datetime import datetime
+
+class TimestampExtractor:
+    def __init__(self,
+                 zone:np.ndarray=None,
+                 allowlist='0123456789-:/'):
+        """
+        args:
+            zone (np.ndarray): The coordinates of the timestamp area in the image.
+            allowlist (str): Characters to allow in the OCR process.
+        """
+        self.zone = zone
+        self.reader = easyocr.Reader(['en'])
+        self.allowlist = allowlist
+
+    def extract_timestamp(self,
+                          img:np.ndarray,
+                          gray:bool=False,
+                          ) -> datetime:
+        """
+        Extract timestamp from the image using OCR.
+        args:
+            img (np.ndarray): The input image from which to extract the timestamp.
+        returns:
+            datetime: The extracted timestamp.
+        """
+        # Ensure the zone is a numpy array
+        if self.zone is None:
+            crop_img = img
+        else:
+            x1, y1, x2, y2 = self.zone[0][0], self.zone[0][1], self.zone[2][0], self.zone[2][1]
+            crop_img = img[y1:y2, x1:x2]
+
+        if gray:
+            crop_img = cv2.cvtColor(crop_img, cv2.COLOR_BGR2GRAY)
+        else:
+            crop_img = crop_img
+
+        result = self.reader.readtext(crop_img, detail=0, allowlist='0123456789-:/')
+        dt = dateparser.parse(" ".join(result))
+        return dt
+
+    def extract_timestamp_video(self,
+                                video_path:str,
+                                frame:int=0,
+                                gray:bool=False,
+                                ) -> datetime:
+        """
+        Extract timestamp from a video file using OCR.
+        args:
+            video_path (str): The path to the video file.
+            frame (int): The frame index to extract the timestamp from.
+            gray (bool): Whether to convert the image to grayscale before OCR.
+        returns:
+            datetime: The extracted timestamp.
+        """
+        cap = cv2.VideoCapture(video_path)
+        cap.set(cv2.CAP_PROP_POS_FRAMES, frame)
+        ret, img = cap.read()
+        cap.release()
+
+        if not ret:
+            raise ValueError("Could not read video file")
+
+        return self.extract_timestamp(img, gray=gray)
+
+    def extract_timestamp_video_auto(self,
+                                     video_path:str,
+                                     gray:bool=False,
+                                     ) -> tuple:
+        """
+        Extract the initial timestamp from the first readable frame in a video file using OCR.
+        args:
+            video_path (str): The path to the video file.
+            gray (bool): Whether to convert the image to grayscale before OCR.
+        returns:
+            dt, frame: The extracted timestamp and the frame index.
+        """
+        cap = cv2.VideoCapture(video_path)
+        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+        tot = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+        for frame in range(0, tot-1):
+            ret, img = cap.read()
+            if not ret:
+                break
+            dt = self.extract_timestamp(img, gray=gray)
+            if dt:
+                cap.release()
+                return dt, frame
+
+        cap.release()
+        return None, None
+
+if __name__ == "__main__":
+    img = cv2.imread('/mnt/d/videos/sample/frames/007_day.png')
+    zone = np.array([[62, 1003], [543, 1000], [547, 1059], [67, 1061]])
+    extractor = TimestampExtractor(zone=zone)
+    timestamp = extractor.extract_timestamp(img)
+    print("Extracted Timestamp:")
+    print(timestamp)
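Beyond the still-image `__main__` example in the new module, the same class can scan a video for its first readable timestamp. A minimal sketch with a placeholder video path; not part of the diff:

```python
import numpy as np
from dnt.detect.timestamp import TimestampExtractor

# Same zone layout as the __main__ block: four corners of the on-screen timestamp.
zone = np.array([[62, 1003], [543, 1000], [547, 1059], [67, 1061]])
extractor = TimestampExtractor(zone=zone)

# Reads frames from the start until dateparser returns a parseable datetime.
dt, frame = extractor.extract_timestamp_video_auto('camera.mp4', gray=True)
if dt is not None:
    print(f"First readable timestamp {dt} at frame {frame}")
```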
dnt/detect/yolov8/detector.py
CHANGED
@@ -1,3 +1,6 @@
+'''
+This module is a wrapper for ultralytics YOLOV8 and RTDETR models.
+'''
 from ultralytics import YOLO, RTDETR
 import cv2
 import pandas as pd

@@ -7,21 +10,38 @@ from pathlib import Path
 import torch
 
 class Detector:
-    def __init__(self,
-
+    def __init__(self,
+                 model:str='yolo',
+                 weights:str='x',
+                 conf:float=0.25,
+                 nms:float=0.7,
+                 max_det:int=300,
+                 device:str='auto',
+                 half:bool=False):
        '''
-
-
-
-
+        Initialize the Detector class.
+        Parameters:
+            model: 'yolo' (default), 'rtdetr'
+            weights: 'x' - extra-large (default), 'l' - large, 'm' - mdium, 's' - small, 'n' - nano, '.pt' - custom model weights
+            conf: 0.25 (default) - confidence threshold
+            nms: 0.7 (default) - non-max suppression threshold
+            max_det: 300 (default) - maximum number of detections per image
+            device: 'auto' (default), 'cuda', 'cpu', 'mps'
+            half: False (default), True
        '''
-        # Load
+        # Load model
        cwd = Path(__file__).parent.absolute()
        if model == 'yolo':
            if weights in ['x', 'l', 'm', 's', 'n']:
                model_path = os.path.join(cwd, 'models/yolov8'+weights+'.pt')
            elif ".pt" in weights:
                model_path = os.path.join(cwd, 'models/'+weights)
+        elif model == 'yolo11':
+            if weights in ['x', 'l', 'm', 's', 'n']:
+                model_path = os.path.join(cwd, 'models/yolo11'+weights+'.pt')
+            elif ".pt" in weights:
+                model_path = os.path.join(cwd, 'models/'+weights)
+
        elif model == 'rtdetr':
            if weights in ['x', 'l']:
                model_path = os.path.join(cwd, 'models/rtdetr-'+weights+'.pt')

@@ -39,7 +59,7 @@ class Detector:
        else:
            raise Exception('Invalid detection model type!')
 
-        if model == 'yolo':
+        if model == 'yolo' or model == 'yolo11':
            self.model = YOLO(model_path)
        elif model == 'rtdetr':
            self.model = RTDETR(model_path)

@@ -63,10 +83,27 @@ class Detector:
 
        self.half = half
 
-    def detect(self,
-
-
-
+    def detect(self,
+               input_video:str,
+               iou_file:str=None,
+               video_index:int=None,
+               video_tot:int=None,
+               start_frame:int=None,
+               end_frame:int=None,
+               verbose:bool=True) -> pd.DataFrame:
+        '''
+        Detect objects in a video file.
+        Parameters:
+            input_video: path to the input video file
+            iou_file: path to the output file
+            video_index: index of the video in the batch
+            video_tot: total number of videos in the batch
+            start_frame: start frame number
+            end_frame: end frame number
+            verbose: True (default), False
+        Returns:
+            df: pandas DataFrame containing the detection results
+        '''
        # open input video
        cap = cv2.VideoCapture(input_video)
        if not cap.isOpened():

@@ -76,16 +113,14 @@ class Detector:
        results = []
 
        video_fps = int(cap.get(cv2.CAP_PROP_FPS)) #original fps
-        if
-            start_frame
-            if start_frame > int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1:
+        if start_frame:
+            if (start_frame > int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1) or (start_frame < 0):
                start_frame = 0
        else:
            start_frame = 0
 
-        if
-            end_frame
-            if end_frame > int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1:
+        if end_frame:
+            if (end_frame > int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1) or (end_frame < 0):
                end_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1
        else:
            end_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1

@@ -93,12 +128,13 @@ class Detector:
        if start_frame>=end_frame:
            raise Exception('The given start time exceeds the end time!')
 
-        frame_total = end_frame - start_frame
-
-
-
-
-
+        frame_total = end_frame - start_frame
+        if verbose:
+            pbar = tqdm(total=frame_total, unit=" frames")
+            if video_index and video_tot:
+                pbar.set_description_str("Detecting {} of {}".format(video_index, video_tot))
+            else:
+                pbar.set_description_str("Detecting ")
 
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        while cap.isOpened():

@@ -120,9 +156,11 @@ class Detector:
                result.insert(1, 'res', -1)
                results.append(result)
 
-
+                if verbose:
+                    pbar.update()
 
-
+        if verbose:
+            pbar.close()
        cap.release()
        #cv2.destroyAllWindows()
 

@@ -179,7 +217,11 @@ class Detector:
        pbar.close()
        cap.release()
 
-
+        if len(results) == 0:
+            return pd.DataFrame(columns=['frame', 'res', 'x', 'y', 'w', 'h', 'conf', 'class'])
+
+        df = pd.concat(d for d in results if not d.empty) # remove empty dataframes
+        #df = pd.concat(results)
        df['x'] = round(df['x'], 1)
        df['y'] = round(df['y'], 1)
        df['w'] = round(df['x2']-df['x'], 0)

@@ -271,6 +313,7 @@ class Detector:
 
 if __name__=='__main__':
 
-    detector = Detector()
-    result = detector.
+    detector = Detector(model='yolo11')
+    result = detector.detect('/mnt/d/videos/sample/traffic.mp4', verbose=True)
+
    print(result)
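Besides the yolo11 switch shown in the `__main__` hunk, `detect()` now resets out-of-range frame bounds and can run without a progress bar. A minimal sketch with a placeholder video path; not part of the diff:

```python
from dnt.detect.yolov8.detector import Detector

# 'yolo11' resolves to models/yolo11<size>.pt; 'yolo' keeps the YOLOv8 paths.
detector = Detector(model='yolo11', weights='m', conf=0.3, device='auto')

# start_frame/end_frame outside [0, frame_count-1] are reset by detect();
# verbose=False suppresses the tqdm progress bar.
df = detector.detect('traffic.mp4', start_frame=0, end_frame=300, verbose=False)
print(df.head())
```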
dnt/detect/yolov8/segmentor.py
CHANGED
@@ -4,15 +4,18 @@ import pandas as pd
 from tqdm import tqdm
 import random, os, json
 from pathlib import Path
+import torch
+from shapely.geometry import Polygon as Pol
 
 class Segmentor:
-    def __init__(self, model:str='
+    def __init__(self, model:str='yolov8m-seg.pt', conf:float=0.25, nms:float=0.7, max_det:int=300, device:str='auto', half:bool=False):
        '''
        yolo: x - yolov8x-seg.pt
              l - yolov8l-seg.pt
              m - yolov8m-seg.pt
              s - yolov8s-seg.pt
              n - yolov8n-seg.pt
+
        '''
        # Load YOLOV8 model
        cwd = Path(__file__).parent.absolute()

@@ -21,6 +24,17 @@ class Segmentor:
        self.conf = conf
        self.nms = nms
        self.max_det = max_det
+
+        if device == 'auto':
+            self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
+        elif device == 'cuda' and torch.cuda.is_available():
+            self.device = 'cuda'
+        elif device == 'mps' and torch.backends.mps.is_available():
+            self.device = 'mps'
+        else:
+            self.device = 'cpu'
+
+        self.half = half
        self.half = half
 
    def segment(self, input_video:str, mask_file:str=None,

@@ -106,8 +120,52 @@ class Segmentor:
                             video_index=None, video_tot=None,
                             start_frame=frame, end_frame=frame,
                             verbose=False)
+
+    def segment_zone(self, input_video:str, frame:int=None, zone:list=None) -> list:
+        '''
+        Segment the video in a specific zone
+        input_video: Path to video
+        frame: Frame number to segment
+        zone: Polygon zone to segment
+        Return:
+            List of segmented objects in the zone
+        '''
+
+        # open input video
+        cap = cv2.VideoCapture(input_video)
+        if not cap.isOpened():
+            print('Failed to open the video!')
 
+        if frame <0 or frame > int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1:
+            raise Exception('The given frame is out of range!')
+
+        cap.set(cv2.CAP_PROP_POS_FRAMES, frame)
+        ret, frame = cap.read()
+        if ret:
+            x1, y1, x2, y2 = zone
+            cropped_frame = frame[y1:y2, x1:x2]
+            detects = self.model.predict(cropped_frame, verbose=False, conf=self.conf, iou=self.nms, max_det=self.max_det, half=self.half)
+            if len(detects)>0:
+                if detects[0].masks is not None:
+                    xy = list(map(lambda x: x.tolist(), detects[0].masks.xy))
+                    print(xy)
+                    input('.')
+        cap.release()
+
+
 if __name__=="__main__":
 
    segmentor = Segmentor()
-    segmentor.segment('/mnt/d/videos/samples/traffic_short.mp4', '/mnt/d/videos/samples/traffic_short_mask.json', start_frame=1, end_frame=1)
+    #segmentor.segment('/mnt/d/videos/samples/traffic_short.mp4', '/mnt/d/videos/samples/traffic_short_mask.json', start_frame=1, end_frame=1)
+    input_video = '/mnt/d/videos/hyde/after_2024-08-15_nw.mp4'
+    tracks = pd.read_csv('/mnt/d/videos/hyde/tracks/after_2024-08-15_nw_track_veh_rt.txt')
+    tracks.columns = ['frame', 'track', 'x', 'y', 'w', 'h', 'r1', 'r2', 'r3', 'r4']
+    dets = tracks[tracks['track'] == 678]
+    dets.sort_values(by='frame', inplace=True)
+
+    for index, det in dets.iterrows():
+        zone = (int(det['x']), int(det['y']), int(det['x'] + det['w']), int(det['y'] + det['h']))
+        frame = int(det['frame'])
+
+        segmentor.segment_zone(input_video, frame, zone) # Segment the specific zone
+        print(f"Segmented frame {frame}")