dnt 0.3.1.3__py3-none-any.whl → 0.3.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dnt might be problematic.

dnt/__init__.py CHANGED
@@ -2,4 +2,4 @@ import sys, os
  sys.path.append(os.path.dirname(__file__))
  sys.path.append(os.path.join(os.path.dirname(__file__), 'third_party/fast-reid'))
 
- __version__='0.3.1.3'
+ __version__='0.3.1.7'
dnt/analysis/__init__.py CHANGED
@@ -3,6 +3,6 @@ import sys
 
  sys.path.append(os.path.dirname(__file__))
 
- from .stop3 import StopAnalyzer
- from .count import Count
- from .interaction import YieldAnalyzer
+ from stop3 import StopAnalyzer
+ from count import Count
+ from interaction2 import YieldAnalyzer
dnt/analysis/count.py CHANGED
@@ -1,54 +1,73 @@
- from shapely.geometry import Point, Polygon, LineString, box
- import geopandas as gpd, pandas as pd
- import datetime
+ # pylint: disable=too-many-arguments
+ '''
+ This script is used to count the number of tracks that passing a reference line for post-analysis
+ last modified: 2021-09-30
+ '''
+ from shapely.geometry import Polygon, LineString
+ import geopandas as gpd
+ import pandas as pd
  from tqdm import tqdm
  import numpy as np
 
  class Count:
- def __init__(self, track_fields: list[str] = ['frame', 'track_id', 'x', 'y', 'w', 'h', 'r1', 'r2', 'r3', 'r4']) -> None:
- self.track_fields = track_fields
+ '''
+ Count tracks that passing a reference line for post-analysis
+ Methods:
+ __init__(): initialize the class
+ count_tracks_by_line():
+ count tracks that passing a reference line for post-analysis
+ '''
+ def __init__(self) -> None:
+ self.track_fields = ['frame', 'track_id', 'x', 'y', 'w', 'h', 'r1', 'r2', 'r3', 'r4']
 
- def count_tracks_by_line_post(self, tracks:pd.DataFrame = None, track_file:str=None,
- line: LineString = None,
- video_index:int = None, video_tot:int = None) -> pd.DataFrame:
+ def count_tracks_by_line(self,
+ tracks:pd.DataFrame=None,
+ track_file:str=None,
+ line: LineString=None,
+ video_index:int=None,
+ video_tot:int=None,
+ verbose:bool=True) -> pd.DataFrame:
  '''
  Count tracks that passing a reference line for post-analysis
  Inputs:
- tracks - a DataFrame of tracks, [FRAME, TRACK_ID, TOPX, TOPY, WIDTH, LENGTH, RESERVED, RESERVED, RESERVED, RESERVED],
- if None, read track_file
+ tracks - a DataFrame of tracks, if None (default), read track_file
  track_file - a txt file contains tracks,
- line - the reference line
+ line - the reference line, LineString(pointA, pointB)
  video_index - the index of video for processing
  video_tot - the total number of videos
+ verbose - if True, show progress bar
  Return:
  A DataFrame of track_id, frame, direction, count
  '''
-
+ # Load tracks
  if tracks is None:
  tracks = pd.read_csv(track_file, header=None)
-
  if len(tracks.columns) != len(self.track_fields):
  raise Exception("The tracks format is incorrect!")
-
  tracks.columns = self.track_fields
  track_ids = tracks['track_id'].unique()
 
- g = tracks.apply(lambda track: Polygon([(track['x'], track['y']), (track['x'] + track['w'], track['y']),
- (track['x'] + track['w'], track['y'] + track['h']), (track['x'], track['y'] + track['h'])]), axis =1)
- geo_tracks = gpd.GeoDataFrame(tracks, geometry=g)
-
- a = np.array(line.coords[0])
- b = np.array(line.coords[1])
-
- pbar = tqdm(total=len(track_ids), unit=' tracks')
- if video_index and video_tot:
- pbar.set_description_str("Counting {} of {}".format(video_index, video_tot))
- else:
- pbar.set_description_str("Counting")
+ # Create a GeoDataFrame of tracks
+ geo = tracks.apply(lambda track: Polygon([(track['x'], track['y']),
+ (track['x'] + track['w'], track['y']),
+ (track['x'] + track['w'], track['y'] + track['h']),
+ (track['x'], track['y'] + track['h'])]),
+ axis =1)
+ geo_tracks = gpd.GeoDataFrame(tracks, geometry=geo)
+ point_a = np.array(line.coords[0])
+ point_b = np.array(line.coords[1])
 
+ # Interate through all tracks
+ if verbose:
+ pbar = tqdm(total=len(track_ids), unit=' tracks')
+ if video_index and video_tot:
+ pbar.set_description_str("Counting {} of {}".format(video_index, video_tot))
+ else:
+ pbar.set_description_str("Counting")
  intersected_tracks = []
  intersected_frames = []
- intersected_direct = []
+ intersected_direct = []
+
  for track_id in track_ids:
  selected = geo_tracks.loc[(geo_tracks['track_id']==track_id)].copy()
  if len(selected)>0:
@@ -60,19 +79,19 @@ class Count:
  intersected_tracks.append(track_id)
  frame_pos = int(len(intersected)/2)
  intersected_frames.append(intersected['frame'].values[frame_pos])
-
+
  # center point of the first frame
- p = np.array((intersected.iloc[0]['x'] + intersected.iloc[0]['w']/2,
+ point_c = np.array((intersected.iloc[0]['x'] + intersected.iloc[0]['w']/2,
  intersected.iloc[0]['y'] + intersected.iloc[0]['h']/2))
  d = 2 # right
- if np.cross(p-a, b-a) < 0:
+ if np.cross(point_c-point_a, point_b-point_a) < 0:
  d = 1 # left
-
  intersected_direct.append(d)
-
- pbar.update()
-
- pbar.close()
+ if verbose:
+ pbar.update()
+
+ if verbose:
+ pbar.close()
 
  results = pd.DataFrame(
  {
@@ -86,5 +105,3 @@ class Count:
  results['count'] = results.index + 1
 
  return results
-
-
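For orientation (not part of the package diff): a minimal usage sketch of the renamed count_tracks_by_line() method, which replaces count_tracks_by_line_post() and adds a verbose flag. The track file path and reference-line coordinates below are illustrative assumptions, and the import relies on the dnt/analysis/__init__.py shown above.

    from shapely.geometry import LineString
    from dnt.analysis import Count

    counter = Count()  # track columns are now fixed to frame, track_id, x, y, w, h, r1-r4
    ref_line = LineString([(100, 500), (1800, 500)])  # hypothetical reference line in pixel coordinates

    # Returns a DataFrame of track_id, frame, direction (1 = left, 2 = right of the line), count
    results = counter.count_tracks_by_line(track_file='tracks.txt', line=ref_line, verbose=True)
    print(results.head())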
@@ -1,21 +1,36 @@
+ """
+ This module is used to analyze the interaction between two tracks.
+ Author: Zhenyu Wang
+ Date: 2023/10/20
+ Email: wonstran@hotmail.com
+ Updates:
+ 2025/03/19: Add the function of nearmiss based on PET
+ 2023/10/20: Initial version
+ """
  from shapely.geometry import Point, Polygon, LineString, box
  from shapely import intersection, distance, intersects
  from shapelysmooth import taubin_smooth, chaikin_smooth, catmull_rom_smooth
  import geopandas as gpd, pandas as pd
  from tqdm import tqdm
  import numpy as np
- from matplotlib import pyplot as plt
  from dnt.label.labeler2 import Labeler
  import os
- from ..track import Tracker
 
  class YieldAnalyzer:
- def __init__(self, waiting_dist_p:int=600, waiting_dist_y:int=300,
- leading_p:bool=False, leading_axis_p:str='y',
- leading_y:bool=True, leading_axis_y:str='x',
- yield_gap:int=10, fps:int=30, ref_point='bc',
- ref_offset:tuple=(0,0), filter_buffer:int=0,
- p_zone:Polygon=None, y_zone:Polygon=None) -> None:
+ def __init__(self,
+ waiting_dist_p:int=600,
+ waiting_dist_y:int=300,
+ leading_p:bool=False,
+ leading_axis_p:str='y',
+ leading_y:bool=True,
+ leading_axis_y:str='x',
+ yield_gap:int=10,
+ fps:int=30,
+ ref_point='bc',
+ ref_offset:tuple=(0,0),
+ filter_buffer:int=0,
+ p_zone:Polygon=None,
+ y_zone:Polygon=None) -> None:
  '''
  Parameters:
  threshold: the hyperparameter to determine if a yield event (frame difference <=yield_gap*fps), default is 3 seconds
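A construction sketch based only on the signature shown above (not from the package). The zone polygons are hypothetical placeholders, and the downstream analysis calls are omitted because they are outside this hunk.

    from shapely.geometry import Polygon
    from dnt.analysis import YieldAnalyzer

    analyzer = YieldAnalyzer(
        yield_gap=10,    # per the docstring, a yield event requires frame difference <= yield_gap*fps
        fps=30,
        ref_point='bc',  # reference point of the bounding box
        p_zone=Polygon([(0, 0), (200, 0), (200, 200), (0, 200)]),      # illustrative zone polygon
        y_zone=Polygon([(200, 0), (400, 0), (400, 200), (200, 200)]),  # illustrative zone polygon
    )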
dnt/analysis/stop3.py CHANGED
@@ -6,11 +6,15 @@ from tqdm import tqdm
  from cython_bbox import bbox_overlaps
  import numpy as np
  #import mapply
- from ..filter import Filter
+ import sys
+ import os
+ sys.path.append(os.path.dirname(__file__))
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+ from filter import Filter
  from matplotlib import pyplot as plt
  import random, os
- from ..label.labeler2 import Labeler
- from ..engine.bbox_iou import ious
+ from label.labeler2 import Labeler
+ from engine.bbox_iou import ious
 
  class StopAnalyzer():
  def __init__(self,
@@ -12,19 +12,27 @@ import torchvision.transforms as transforms
  from torchvision.models.resnet import ResNet18_Weights
  from shared.download import download_file
  from matplotlib import pyplot as plt
+ from ultralytics import YOLO
 
  class Model(nn.Module):
- def __init__(self):
+ def __init__(self, num_class=2):
  super(Model, self).__init__()
 
  self.resnet18 = models.resnet18(weights=ResNet18_Weights.DEFAULT)
- self.resnet18.fc = nn.Linear(512, 2)
+ self.resnet18.fc = nn.Linear(512, num_class)
 
  def forward(self, x):
  return self.resnet18(x)
 
  class SignalDetector:
- def __init__(self, det_zones:list, model:str='ped', weights:str=None, batchsz:int=64, threshold:float=0.98, device='auto'):
+ def __init__(self,
+ det_zones:list,
+ model:str='ped',
+ weights:str=None,
+ batchsz:int=64,
+ num_class:int=2,
+ threshold:float=0.98,
+ device='auto'):
  '''
  Detect traffic signal status
 
@@ -33,6 +41,7 @@ class SignalDetector:
  - model: detection model, default is 'ped', 'custom'
  - weights: path of weights, default is None
  - batchsz: the batch size for prediction, default is 64
+ - num_class: the number of classes, default is 2
  - threshold: the threshold for detection, default is 0.98
  '''
 
@@ -47,7 +56,7 @@ class SignalDetector:
  url = 'https://its.cutr.usf.edu/alms/downloads/ped_signal.pt'
  download_file(url, weights)
 
- self.model = Model()
+ self.model = Model(num_class)
  if device == 'auto':
  self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
  elif device == 'cuda' and torch.cuda.is_available():
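A usage sketch for the widened SignalDetector constructor (not from the package). The detection zone, weights filename, and class count are illustrative assumptions, and the module path is not shown in this diff.

    # from dnt... import SignalDetector   # module path not shown in this diff
    zones = [[(100, 50), (180, 50), (180, 120), (100, 120)]]  # hypothetical signal-head zone
    detector = SignalDetector(det_zones=zones,
                              model='custom',
                              weights='signal_3class.pt',  # hypothetical weights file
                              batchsz=64,
                              num_class=3,                 # new parameter in this release
                              threshold=0.98,
                              device='auto')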
@@ -0,0 +1,105 @@
+ import easyocr
+ import cv2
+ import numpy as np
+ import dateparser
+ from datetime import datetime
+
+ class TimestampExtractor:
+ def __init__(self,
+ zone:np.ndarray=None,
+ allowlist='0123456789-:/'):
+ """
+ args:
+ zone (np.ndarray): The coordinates of the timestamp area in the image.
+ allowlist (str): Characters to allow in the OCR process.
+ """
+ self.zone = zone
+ self.reader = easyocr.Reader(['en'])
+ self.allowlist = allowlist
+
+ def extract_timestamp(self,
+ img:np.ndarray,
+ gray:bool=False,
+ ) -> datetime:
+ """
+ Extract timestamp from the image using OCR.
+ args:
+ img (np.ndarray): The input image from which to extract the timestamp.
+ returns:
+ datetime: The extracted timestamp.
+ """
+ # Ensure the zone is a numpy array
+ if self.zone is None:
+ crop_img = img
+ else:
+ x1, y1, x2, y2 = self.zone[0][0], self.zone[0][1], self.zone[2][0], self.zone[2][1]
+ crop_img = img[y1:y2, x1:x2]
+
+ if gray:
+ crop_img = cv2.cvtColor(crop_img, cv2.COLOR_BGR2GRAY)
+ else:
+ crop_img = crop_img
+
+ result = self.reader.readtext(crop_img, detail=0, allowlist='0123456789-:/')
+ dt = dateparser.parse(" ".join(result))
+ return dt
+
+ def extract_timestamp_video(self,
+ video_path:str,
+ frame:int=0,
+ gray:bool=False,
+ ) -> datetime:
+ """
+ Extract timestamp from a video file using OCR.
+ args:
+ video_path (str): The path to the video file.
+ frame (int): The frame index to extract the timestamp from.
+ gray (bool): Whether to convert the image to grayscale before OCR.
+ returns:
+ datetime: The extracted timestamp.
+ """
+ cap = cv2.VideoCapture(video_path)
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame)
+ ret, img = cap.read()
+ cap.release()
+
+ if not ret:
+ raise ValueError("Could not read video file")
+
+ return self.extract_timestamp(img, gray=gray)
+
+ def extract_timestamp_video_auto(self,
+ video_path:str,
+ gray:bool=False,
+ ) -> tuple:
+ """
+ Extract the initial timestamp from the first readable frame in a video file using OCR.
+ args:
+ video_path (str): The path to the video file.
+ gray (bool): Whether to convert the image to grayscale before OCR.
+ returns:
+ dt, frame: The extracted timestamp and the frame index.
+ """
+ cap = cv2.VideoCapture(video_path)
+ cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+ tot = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ for frame in range(0, tot-1):
+ ret, img = cap.read()
+ if not ret:
+ break
+ dt = self.extract_timestamp(img, gray=gray)
+ if dt:
+ cap.release()
+ return dt, frame
+
+ cap.release()
+ return None, None
+
+ if __name__ == "__main__":
+ img = cv2.imread('/mnt/d/videos/sample/frames/007_day.png')
+ zone = np.array([[62, 1003], [543, 1000], [547, 1059], [67, 1061]])
+ extractor = TimestampExtractor(zone=zone)
+ timestamp = extractor.extract_timestamp(img)
+ print("Extracted Timestamp:")
+ print(timestamp)
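Besides the image example in the __main__ block above, a sketch of the video helpers added in this file (not part of the package); the video path is illustrative and the module path is not shown in this diff.

    # from dnt... import TimestampExtractor   # module path not shown in this diff
    extractor = TimestampExtractor(zone=None)  # with no zone, the whole frame is passed to OCR
    dt, frame = extractor.extract_timestamp_video_auto('camera.mp4', gray=True)
    if dt is not None:
        print(f"First readable timestamp {dt} at frame {frame}")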
@@ -1,3 +1,6 @@
+ '''
+ This module is a wrapper for ultralytics YOLOV8 and RTDETR models.
+ '''
  from ultralytics import YOLO, RTDETR
  import cv2
  import pandas as pd
@@ -7,21 +10,38 @@ from pathlib import Path
  import torch
 
  class Detector:
- def __init__(self, model:str='yolo', weights:str='x', conf:float=0.25, nms:float=0.7, max_det:int=300,
- device:str='auto', half:bool=False):
+ def __init__(self,
+ model:str='yolo',
+ weights:str='x',
+ conf:float=0.25,
+ nms:float=0.7,
+ max_det:int=300,
+ device:str='auto',
+ half:bool=False):
  '''
- model: 'yolo' (default), 'rtdetr'
- weights: 'x' - extra-large (default), 'l' - large, 'm' - mdium, 's' - small, 'n' - nano, '.pt' - custom model weights
- device: 'auto' (default), 'cuda', 'cpu', 'mps'
- half: False (default), True
+ Initialize the Detector class.
+ Parameters:
+ model: 'yolo' (default), 'rtdetr'
+ weights: 'x' - extra-large (default), 'l' - large, 'm' - mdium, 's' - small, 'n' - nano, '.pt' - custom model weights
+ conf: 0.25 (default) - confidence threshold
+ nms: 0.7 (default) - non-max suppression threshold
+ max_det: 300 (default) - maximum number of detections per image
+ device: 'auto' (default), 'cuda', 'cpu', 'mps'
+ half: False (default), True
  '''
- # Load YOLOV8 model
+ # Load model
  cwd = Path(__file__).parent.absolute()
  if model == 'yolo':
  if weights in ['x', 'l', 'm', 's', 'n']:
  model_path = os.path.join(cwd, 'models/yolov8'+weights+'.pt')
  elif ".pt" in weights:
  model_path = os.path.join(cwd, 'models/'+weights)
+ elif model == 'yolo11':
+ if weights in ['x', 'l', 'm', 's', 'n']:
+ model_path = os.path.join(cwd, 'models/yolo11'+weights+'.pt')
+ elif ".pt" in weights:
+ model_path = os.path.join(cwd, 'models/'+weights)
+
  elif model == 'rtdetr':
  if weights in ['x', 'l']:
  model_path = os.path.join(cwd, 'models/rtdetr-'+weights+'.pt')
@@ -39,7 +59,7 @@ class Detector:
  else:
  raise Exception('Invalid detection model type!')
 
- if model == 'yolo':
+ if model == 'yolo' or model == 'yolo11':
  self.model = YOLO(model_path)
  elif model == 'rtdetr':
  self.model = RTDETR(model_path)
@@ -63,10 +83,27 @@ class Detector:
 
  self.half = half
 
- def detect(self, input_video:str, iou_file:str=None,
- video_index:int=None, video_tot:int=None,
- start_time:int=None, end_time:int=None, verbose:bool=False) -> pd.DataFrame:
-
+ def detect(self,
+ input_video:str,
+ iou_file:str=None,
+ video_index:int=None,
+ video_tot:int=None,
+ start_frame:int=None,
+ end_frame:int=None,
+ verbose:bool=True) -> pd.DataFrame:
+ '''
+ Detect objects in a video file.
+ Parameters:
+ input_video: path to the input video file
+ iou_file: path to the output file
+ video_index: index of the video in the batch
+ video_tot: total number of videos in the batch
+ start_frame: start frame number
+ end_frame: end frame number
+ verbose: True (default), False
+ Returns:
+ df: pandas DataFrame containing the detection results
+ '''
  # open input video
  cap = cv2.VideoCapture(input_video)
  if not cap.isOpened():
@@ -76,16 +113,14 @@
  results = []
 
  video_fps = int(cap.get(cv2.CAP_PROP_FPS)) #original fps
- if start_time:
- start_frame = int(video_fps * start_time)
- if start_frame > int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1:
+ if start_frame:
+ if (start_frame > int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1) or (start_frame < 0):
  start_frame = 0
  else:
  start_frame = 0
 
- if end_time:
- end_frame = int(video_fps * end_time)
- if end_frame > int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1:
+ if end_frame:
+ if (end_frame > int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1) or (end_frame < 0):
  end_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1
  else:
  end_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1
@@ -93,12 +128,13 @@
  if start_frame>=end_frame:
  raise Exception('The given start time exceeds the end time!')
 
- frame_total = end_frame - start_frame
- pbar = tqdm(total=frame_total, unit=" frames")
- if video_index and video_tot:
- pbar.set_description_str("Detecting {} of {}".format(video_index, video_tot))
- else:
- pbar.set_description_str("Detecting ")
+ frame_total = end_frame - start_frame
+ if verbose:
+ pbar = tqdm(total=frame_total, unit=" frames")
+ if video_index and video_tot:
+ pbar.set_description_str("Detecting {} of {}".format(video_index, video_tot))
+ else:
+ pbar.set_description_str("Detecting ")
 
  cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
  while cap.isOpened():
@@ -120,9 +156,11 @@
  result.insert(1, 'res', -1)
  results.append(result)
 
- pbar.update()
+ if verbose:
+ pbar.update()
 
- pbar.close()
+ if verbose:
+ pbar.close()
  cap.release()
  #cv2.destroyAllWindows()
 
@@ -179,7 +217,11 @@
  pbar.close()
  cap.release()
 
- df = pd.concat(results)
+ if len(results) == 0:
+ return pd.DataFrame(columns=['frame', 'res', 'x', 'y', 'w', 'h', 'conf', 'class'])
+
+ df = pd.concat(d for d in results if not d.empty) # remove empty dataframes
+ #df = pd.concat(results)
  df['x'] = round(df['x'], 1)
  df['y'] = round(df['y'], 1)
  df['w'] = round(df['x2']-df['x'], 0)
@@ -271,6 +313,7 @@
 
  if __name__=='__main__':
 
- detector = Detector()
- result = detector.detect_frames('/mnt/d/videos/samples/traffic_short.mp4', [0, 1])
+ detector = Detector(model='yolo11')
+ result = detector.detect('/mnt/d/videos/sample/traffic.mp4', verbose=True)
+
  print(result)
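Since detect() now takes start_frame/end_frame instead of start_time/end_time, callers that previously passed seconds should pass frame indices. A hedged sketch (not from the package; the video path and frame numbers are illustrative, and the module path is not shown in this diff):

    # from dnt... import Detector   # module path not shown in this diff
    detector = Detector(model='yolo11', weights='m', device='auto')
    # Out-of-range frame bounds fall back to the full clip per the new checks above
    df = detector.detect('traffic.mp4', start_frame=0, end_frame=900, verbose=True)
    print(df[['frame', 'x', 'y', 'w', 'h', 'conf', 'class']].head())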
@@ -4,15 +4,18 @@ import pandas as pd
  from tqdm import tqdm
  import random, os, json
  from pathlib import Path
+ import torch
+ from shapely.geometry import Polygon as Pol
 
  class Segmentor:
- def __init__(self, model:str='yolov8x-seg.pt', conf:float=0.25, nms:float=0.7, max_det:int=300, half:bool=False):
+ def __init__(self, model:str='yolov8m-seg.pt', conf:float=0.25, nms:float=0.7, max_det:int=300, device:str='auto', half:bool=False):
  '''
  yolo: x - yolov8x-seg.pt
  l - yolov8l-seg.pt
  m - yolov8m-seg.pt
  s - yolov8s-seg.pt
  n - yolov8n-seg.pt
+
  '''
  # Load YOLOV8 model
  cwd = Path(__file__).parent.absolute()
@@ -21,6 +24,17 @@ class Segmentor:
  self.conf = conf
  self.nms = nms
  self.max_det = max_det
+
+ if device == 'auto':
+ self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
+ elif device == 'cuda' and torch.cuda.is_available():
+ self.device = 'cuda'
+ elif device == 'mps' and torch.backends.mps.is_available():
+ self.device = 'mps'
+ else:
+ self.device = 'cpu'
+
+ self.half = half
  self.half = half
 
  def segment(self, input_video:str, mask_file:str=None,
@@ -106,8 +120,52 @@ class Segmentor:
  video_index=None, video_tot=None,
  start_frame=frame, end_frame=frame,
  verbose=False)
+
+ def segment_zone(self, input_video:str, frame:int=None, zone:list=None) -> list:
+ '''
+ Segment the video in a specific zone
+ input_video: Path to video
+ frame: Frame number to segment
+ zone: Polygon zone to segment
+ Return:
+ List of segmented objects in the zone
+ '''
+
+ # open input video
+ cap = cv2.VideoCapture(input_video)
+ if not cap.isOpened():
+ print('Failed to open the video!')
 
+ if frame <0 or frame > int(cap.get(cv2.CAP_PROP_FRAME_COUNT))-1:
+ raise Exception('The given frame is out of range!')
+
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame)
+ ret, frame = cap.read()
+ if ret:
+ x1, y1, x2, y2 = zone
+ cropped_frame = frame[y1:y2, x1:x2]
+ detects = self.model.predict(cropped_frame, verbose=False, conf=self.conf, iou=self.nms, max_det=self.max_det, half=self.half)
+ if len(detects)>0:
+ if detects[0].masks is not None:
+ xy = list(map(lambda x: x.tolist(), detects[0].masks.xy))
+ print(xy)
+ input('.')
+ cap.release()
+
+
 
  if __name__=="__main__":
 
  segmentor = Segmentor()
- segmentor.segment('/mnt/d/videos/samples/traffic_short.mp4', '/mnt/d/videos/samples/traffic_short_mask.json', start_frame=1, end_frame=1)
+ #segmentor.segment('/mnt/d/videos/samples/traffic_short.mp4', '/mnt/d/videos/samples/traffic_short_mask.json', start_frame=1, end_frame=1)
+ input_video = '/mnt/d/videos/hyde/after_2024-08-15_nw.mp4'
+ tracks = pd.read_csv('/mnt/d/videos/hyde/tracks/after_2024-08-15_nw_track_veh_rt.txt')
+ tracks.columns = ['frame', 'track', 'x', 'y', 'w', 'h', 'r1', 'r2', 'r3', 'r4']
+ dets = tracks[tracks['track'] == 678]
+ dets.sort_values(by='frame', inplace=True)
+
+ for index, det in dets.iterrows():
+ zone = (int(det['x']), int(det['y']), int(det['x'] + det['w']), int(det['y'] + det['h']))
+ frame = int(det['frame'])
+
+ segmentor.segment_zone(input_video, frame, zone) # Segment the specific zone
+ print(f"Segmented frame {frame}")
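Finally, a short construction sketch showing the Segmentor's new device argument, which is resolved the same way as the Detector's (not from the package; whether inference actually honors the chosen device is not visible in this hunk).

    # from dnt... import Segmentor   # module path not shown in this diff
    segmentor = Segmentor(model='yolov8m-seg.pt', conf=0.3, device='cpu')  # 'auto' picks cuda/mps when available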