nedo-vision-worker-core 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nedo-vision-worker-core might be problematic. Click here for more details.
- nedo_vision_worker_core/__init__.py +47 -12
- nedo_vision_worker_core/callbacks/DetectionCallbackManager.py +306 -0
- nedo_vision_worker_core/callbacks/DetectionCallbackTypes.py +150 -0
- nedo_vision_worker_core/callbacks/__init__.py +27 -0
- nedo_vision_worker_core/cli.py +24 -34
- nedo_vision_worker_core/core_service.py +121 -55
- nedo_vision_worker_core/database/DatabaseManager.py +2 -2
- nedo_vision_worker_core/detection/BaseDetector.py +2 -1
- nedo_vision_worker_core/detection/DetectionManager.py +2 -2
- nedo_vision_worker_core/detection/RFDETRDetector.py +23 -5
- nedo_vision_worker_core/detection/YOLODetector.py +18 -5
- nedo_vision_worker_core/detection/detection_processing/DetectionProcessor.py +1 -1
- nedo_vision_worker_core/detection/detection_processing/HumanDetectionProcessor.py +57 -3
- nedo_vision_worker_core/detection/detection_processing/PPEDetectionProcessor.py +173 -10
- nedo_vision_worker_core/models/ai_model.py +23 -2
- nedo_vision_worker_core/pipeline/PipelineProcessor.py +299 -14
- nedo_vision_worker_core/pipeline/PipelineSyncThread.py +32 -0
- nedo_vision_worker_core/repositories/PPEDetectionRepository.py +18 -15
- nedo_vision_worker_core/repositories/RestrictedAreaRepository.py +17 -13
- nedo_vision_worker_core/services/SharedVideoStreamServer.py +276 -0
- nedo_vision_worker_core/services/VideoSharingDaemon.py +808 -0
- nedo_vision_worker_core/services/VideoSharingDaemonManager.py +257 -0
- nedo_vision_worker_core/streams/SharedVideoDeviceManager.py +383 -0
- nedo_vision_worker_core/streams/StreamSyncThread.py +16 -2
- nedo_vision_worker_core/streams/VideoStream.py +267 -246
- nedo_vision_worker_core/streams/VideoStreamManager.py +158 -6
- nedo_vision_worker_core/tracker/TrackerManager.py +25 -31
- nedo_vision_worker_core-0.3.1.dist-info/METADATA +444 -0
- {nedo_vision_worker_core-0.2.0.dist-info → nedo_vision_worker_core-0.3.1.dist-info}/RECORD +32 -25
- nedo_vision_worker_core-0.2.0.dist-info/METADATA +0 -347
- {nedo_vision_worker_core-0.2.0.dist-info → nedo_vision_worker_core-0.3.1.dist-info}/WHEEL +0 -0
- {nedo_vision_worker_core-0.2.0.dist-info → nedo_vision_worker_core-0.3.1.dist-info}/entry_points.txt +0 -0
- {nedo_vision_worker_core-0.2.0.dist-info → nedo_vision_worker_core-0.3.1.dist-info}/top_level.txt +0 -0
|
@@ -3,13 +3,20 @@ import time
|
|
|
3
3
|
import signal
|
|
4
4
|
import os
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import Callable, Dict, List, Any
|
|
6
|
+
from typing import Callable, Dict, List, Any, Optional
|
|
7
7
|
|
|
8
8
|
from .util.DrawingUtils import DrawingUtils
|
|
9
9
|
from .streams.VideoStreamManager import VideoStreamManager
|
|
10
10
|
from .streams.StreamSyncThread import StreamSyncThread
|
|
11
11
|
from .pipeline.PipelineSyncThread import PipelineSyncThread
|
|
12
12
|
from .database.DatabaseManager import DatabaseManager
|
|
13
|
+
from .services.VideoSharingDaemonManager import get_daemon_manager
|
|
14
|
+
from .callbacks import (
|
|
15
|
+
DetectionCallbackManager,
|
|
16
|
+
DetectionType,
|
|
17
|
+
CallbackTrigger,
|
|
18
|
+
DetectionData
|
|
19
|
+
)
|
|
13
20
|
# Import models to ensure they are registered with SQLAlchemy Base registry
|
|
14
21
|
from . import models
|
|
15
22
|
import cv2
|
|
@@ -18,18 +25,14 @@ import cv2
|
|
|
18
25
|
class CoreService:
|
|
19
26
|
"""Service class for running the Nedo Vision Core processing."""
|
|
20
27
|
|
|
21
|
-
|
|
22
|
-
_detection_callbacks: Dict[str, List[Callable]] = {
|
|
23
|
-
'ppe_detection': [],
|
|
24
|
-
'area_violation': [],
|
|
25
|
-
'general_detection': []
|
|
26
|
-
}
|
|
28
|
+
_callback_manager: Optional[DetectionCallbackManager] = None
|
|
27
29
|
|
|
28
30
|
def __init__(self,
|
|
29
31
|
drawing_assets_path: str = None,
|
|
30
32
|
log_level: str = "INFO",
|
|
31
33
|
storage_path: str = "data",
|
|
32
|
-
rtmp_server: str = "rtmp://live.vision.sindika.co.id:1935/live"
|
|
34
|
+
rtmp_server: str = "rtmp://live.vision.sindika.co.id:1935/live",
|
|
35
|
+
enable_video_sharing_daemon: bool = True):
|
|
33
36
|
"""
|
|
34
37
|
Initialize the Core Service.
|
|
35
38
|
|
|
@@ -38,11 +41,17 @@ class CoreService:
|
|
|
38
41
|
log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
|
39
42
|
storage_path: Storage path for databases and files (default: data)
|
|
40
43
|
rtmp_server: RTMP server URL for video streaming (default: rtmp://localhost:1935/live)
|
|
44
|
+
enable_video_sharing_daemon: Enable automatic video sharing daemon management (default: True)
|
|
41
45
|
"""
|
|
42
46
|
self.running = True
|
|
43
47
|
self.video_manager = None
|
|
44
48
|
self.stream_sync_thread = None
|
|
45
49
|
self.pipeline_sync_thread = None
|
|
50
|
+
self.enable_video_sharing_daemon = enable_video_sharing_daemon
|
|
51
|
+
|
|
52
|
+
# Initialize callback manager if not already done
|
|
53
|
+
if CoreService._callback_manager is None:
|
|
54
|
+
CoreService._callback_manager = DetectionCallbackManager()
|
|
46
55
|
|
|
47
56
|
# Store configuration parameters
|
|
48
57
|
self.storage_path = storage_path
|
|
@@ -65,73 +74,106 @@ class CoreService:
|
|
|
65
74
|
signal.signal(signal.SIGINT, self._signal_handler)
|
|
66
75
|
signal.signal(signal.SIGTERM, self._signal_handler)
|
|
67
76
|
|
|
77
|
+
# Detection Callback System Methods
|
|
78
|
+
|
|
68
79
|
@classmethod
|
|
69
|
-
def
|
|
80
|
+
def register_callback(cls,
|
|
81
|
+
name: str,
|
|
82
|
+
callback: Callable[[DetectionData], None],
|
|
83
|
+
trigger: CallbackTrigger,
|
|
84
|
+
detection_types: List[DetectionType],
|
|
85
|
+
interval_seconds: Optional[int] = None) -> None:
|
|
70
86
|
"""
|
|
71
|
-
Register a
|
|
87
|
+
Register a detection callback.
|
|
72
88
|
|
|
73
89
|
Args:
|
|
74
|
-
|
|
75
|
-
callback: Function to call when detection occurs
|
|
90
|
+
name: Unique name for the callback
|
|
91
|
+
callback: Function to call when detection occurs
|
|
92
|
+
trigger: When to trigger (ON_NEW_DETECTION or ON_VIOLATION_INTERVAL)
|
|
93
|
+
detection_types: Types of detections to listen for
|
|
94
|
+
interval_seconds: For interval callbacks, how often to call (in seconds)
|
|
95
|
+
|
|
96
|
+
Example:
|
|
97
|
+
# Immediate callback for PPE violations
|
|
98
|
+
CoreService.register_callback(
|
|
99
|
+
"ppe_alert",
|
|
100
|
+
my_ppe_callback,
|
|
101
|
+
CallbackTrigger.ON_NEW_DETECTION,
|
|
102
|
+
[DetectionType.PPE_DETECTION]
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
# Interval callback for area violations every 30 seconds
|
|
106
|
+
CoreService.register_callback(
|
|
107
|
+
"area_summary",
|
|
108
|
+
my_area_summary_callback,
|
|
109
|
+
CallbackTrigger.ON_VIOLATION_INTERVAL,
|
|
110
|
+
[DetectionType.AREA_VIOLATION],
|
|
111
|
+
interval_seconds=30
|
|
112
|
+
)
|
|
76
113
|
"""
|
|
77
|
-
if
|
|
78
|
-
cls.
|
|
114
|
+
if cls._callback_manager is None:
|
|
115
|
+
cls._callback_manager = DetectionCallbackManager()
|
|
79
116
|
|
|
80
|
-
cls.
|
|
81
|
-
|
|
82
|
-
|
|
117
|
+
cls._callback_manager.register_callback(
|
|
118
|
+
name=name,
|
|
119
|
+
callback=callback,
|
|
120
|
+
trigger=trigger,
|
|
121
|
+
detection_types=detection_types,
|
|
122
|
+
interval_seconds=interval_seconds
|
|
123
|
+
)
|
|
124
|
+
|
|
83
125
|
@classmethod
|
|
84
|
-
def
|
|
126
|
+
def unregister_callback(cls, name: str) -> bool:
|
|
85
127
|
"""
|
|
86
|
-
Unregister a callback
|
|
128
|
+
Unregister a callback.
|
|
87
129
|
|
|
88
130
|
Args:
|
|
89
|
-
|
|
90
|
-
|
|
131
|
+
name: Name of the callback to remove
|
|
132
|
+
|
|
133
|
+
Returns:
|
|
134
|
+
True if callback was found and removed, False otherwise
|
|
91
135
|
"""
|
|
92
|
-
if
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
136
|
+
if cls._callback_manager is None:
|
|
137
|
+
return False
|
|
138
|
+
|
|
139
|
+
return cls._callback_manager.unregister_callback(name)
|
|
140
|
+
|
|
96
141
|
@classmethod
|
|
97
|
-
def
|
|
142
|
+
def trigger_detection(cls, detection_data: DetectionData) -> None:
|
|
98
143
|
"""
|
|
99
|
-
Trigger
|
|
144
|
+
Trigger detection callbacks.
|
|
100
145
|
|
|
101
146
|
Args:
|
|
102
|
-
|
|
103
|
-
detection_data: Dict containing detection information
|
|
147
|
+
detection_data: The detection data to process
|
|
104
148
|
"""
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
# Call specific callbacks
|
|
109
|
-
for callback in callbacks:
|
|
110
|
-
try:
|
|
111
|
-
callback(detection_data)
|
|
112
|
-
except Exception as e:
|
|
113
|
-
logging.error(f"❌ Error in {detection_type} callback {callback.__name__}: {e}")
|
|
114
|
-
|
|
115
|
-
# Call general callbacks
|
|
116
|
-
for callback in general_callbacks:
|
|
117
|
-
try:
|
|
118
|
-
callback(detection_data)
|
|
119
|
-
except Exception as e:
|
|
120
|
-
logging.error(f"❌ Error in general detection callback {callback.__name__}: {e}")
|
|
121
|
-
|
|
149
|
+
if cls._callback_manager is not None:
|
|
150
|
+
cls._callback_manager.trigger_detection(detection_data)
|
|
151
|
+
|
|
122
152
|
@classmethod
|
|
123
|
-
def
|
|
153
|
+
def get_callback_stats(cls) -> Dict[str, Any]:
|
|
124
154
|
"""
|
|
125
|
-
|
|
155
|
+
Get statistics about registered callbacks and recent activity.
|
|
126
156
|
|
|
127
157
|
Returns:
|
|
128
|
-
|
|
158
|
+
Dictionary with callback statistics
|
|
129
159
|
"""
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
160
|
+
if cls._callback_manager is None:
|
|
161
|
+
return {"error": "Callback manager not initialized"}
|
|
162
|
+
|
|
163
|
+
return cls._callback_manager.get_callback_stats()
|
|
164
|
+
|
|
165
|
+
@classmethod
|
|
166
|
+
def list_callbacks(cls) -> Dict[str, Dict[str, Any]]:
|
|
167
|
+
"""
|
|
168
|
+
List all callbacks with their configurations.
|
|
169
|
+
|
|
170
|
+
Returns:
|
|
171
|
+
Dictionary mapping callback names to their configurations
|
|
172
|
+
"""
|
|
173
|
+
if cls._callback_manager is None:
|
|
174
|
+
return {}
|
|
175
|
+
|
|
176
|
+
return cls._callback_manager.list_callbacks()
|
|
135
177
|
|
|
136
178
|
def _setup_environment(self):
|
|
137
179
|
"""Set up environment variables for components that still require them (like RTMPStreamer)."""
|
|
@@ -164,6 +206,16 @@ class CoreService:
|
|
|
164
206
|
# Set up environment variables for internal components that still need them
|
|
165
207
|
self._setup_environment()
|
|
166
208
|
|
|
209
|
+
# Initialize video sharing daemon manager if enabled
|
|
210
|
+
if self.enable_video_sharing_daemon:
|
|
211
|
+
daemon_manager = get_daemon_manager()
|
|
212
|
+
daemon_manager.enable_auto_start(True)
|
|
213
|
+
logging.info("🔗 Video sharing daemon auto-start enabled")
|
|
214
|
+
else:
|
|
215
|
+
daemon_manager = get_daemon_manager()
|
|
216
|
+
daemon_manager.enable_auto_start(False)
|
|
217
|
+
logging.info("⚠️ Video sharing daemon auto-start disabled")
|
|
218
|
+
|
|
167
219
|
# Initialize Database with storage path
|
|
168
220
|
DatabaseManager.init_databases(storage_path=self.storage_path)
|
|
169
221
|
|
|
@@ -225,6 +277,20 @@ class CoreService:
|
|
|
225
277
|
|
|
226
278
|
if self.video_manager:
|
|
227
279
|
self.video_manager.stop_all()
|
|
280
|
+
|
|
281
|
+
# Stop video sharing daemons if they were auto-started
|
|
282
|
+
if self.enable_video_sharing_daemon:
|
|
283
|
+
try:
|
|
284
|
+
daemon_manager = get_daemon_manager()
|
|
285
|
+
daemon_manager.stop_all_daemons()
|
|
286
|
+
logging.info("🔗 Video sharing daemons stopped")
|
|
287
|
+
except Exception as e:
|
|
288
|
+
logging.warning(f"⚠️ Error stopping video sharing daemons: {e}")
|
|
289
|
+
|
|
290
|
+
# Stop callback manager
|
|
291
|
+
if CoreService._callback_manager:
|
|
292
|
+
CoreService._callback_manager.stop()
|
|
293
|
+
CoreService._callback_manager = None
|
|
228
294
|
|
|
229
295
|
# Final cleanup
|
|
230
296
|
cv2.destroyAllWindows()
|
|
@@ -234,4 +300,4 @@ class CoreService:
|
|
|
234
300
|
except Exception as e:
|
|
235
301
|
logging.error(f"Error during shutdown: {e}")
|
|
236
302
|
finally:
|
|
237
|
-
logging.info("✅ Nedo Vision Core shutdown complete.")
|
|
303
|
+
logging.info("✅ Nedo Vision Core shutdown complete.")
|
|
@@ -99,11 +99,11 @@ class DatabaseManager:
|
|
|
99
99
|
|
|
100
100
|
# Set storage paths - prioritize parameter over environment variables
|
|
101
101
|
if storage_path:
|
|
102
|
-
DatabaseManager.STORAGE_PATH = Path(storage_path)
|
|
102
|
+
DatabaseManager.STORAGE_PATH = Path(storage_path).resolve()
|
|
103
103
|
else:
|
|
104
104
|
# Fallback to environment variables for backward compatibility
|
|
105
105
|
DatabaseManager._load_env_file()
|
|
106
|
-
DatabaseManager.STORAGE_PATH = Path(os.environ.get("STORAGE_PATH", "data"))
|
|
106
|
+
DatabaseManager.STORAGE_PATH = Path(os.environ.get("STORAGE_PATH", "data")).resolve()
|
|
107
107
|
|
|
108
108
|
DatabaseManager.STORAGE_PATHS = {
|
|
109
109
|
"db": DatabaseManager.STORAGE_PATH / "sqlite",
|
|
@@ -10,12 +10,13 @@ class BaseDetector(ABC):
|
|
|
10
10
|
pass
|
|
11
11
|
|
|
12
12
|
@abstractmethod
|
|
13
|
-
def detect_objects(self, frame, confidence_threshold=0.7):
|
|
13
|
+
def detect_objects(self, frame, confidence_threshold=0.7, class_thresholds=None):
|
|
14
14
|
"""
|
|
15
15
|
Detect objects in the input frame.
|
|
16
16
|
Args:
|
|
17
17
|
frame: Image/frame (numpy array)
|
|
18
18
|
confidence_threshold: Minimum confidence threshold for detections (optional)
|
|
19
|
+
class_thresholds: Dict mapping class names to specific confidence thresholds (optional)
|
|
19
20
|
Returns:
|
|
20
21
|
List of detections: [{"label": str, "confidence": float, "bbox": [x1, y1, x2, y2]}, ...]
|
|
21
22
|
"""
|
|
@@ -77,7 +77,7 @@ class DetectionManager:
|
|
|
77
77
|
self.detector = None
|
|
78
78
|
self.model_metadata = None
|
|
79
79
|
|
|
80
|
-
def detect_objects(self, frame, confidence_threshold=0.7):
|
|
80
|
+
def detect_objects(self, frame, confidence_threshold=0.7, class_thresholds=None):
|
|
81
81
|
if not self.detector:
|
|
82
82
|
return []
|
|
83
|
-
return self.detector.detect_objects(frame, confidence_threshold)
|
|
83
|
+
return self.detector.detect_objects(frame, confidence_threshold, class_thresholds)
|
|
@@ -8,25 +8,30 @@ except ImportError:
|
|
|
8
8
|
RFDETRBase = None
|
|
9
9
|
|
|
10
10
|
from ..database.DatabaseManager import DatabaseManager
|
|
11
|
+
from ..models.ai_model import AIModelEntity
|
|
11
12
|
from .BaseDetector import BaseDetector
|
|
12
13
|
|
|
13
14
|
logging.getLogger("ultralytics").setLevel(logging.WARNING)
|
|
14
15
|
|
|
15
16
|
class RFDETRDetector(BaseDetector):
|
|
16
|
-
def __init__(self, model):
|
|
17
|
+
def __init__(self, model: AIModelEntity):
|
|
17
18
|
if not RFDETR_AVAILABLE:
|
|
18
19
|
raise ImportError(
|
|
19
20
|
"RF-DETR is required but not installed. Install it manually with:\n"
|
|
20
21
|
"pip install rfdetr @ git+https://github.com/roboflow/rf-detr.git@1e63dbad402eea10f110e86013361d6b02ee0c09\n"
|
|
21
22
|
"See the documentation for more details."
|
|
22
23
|
)
|
|
24
|
+
if not isinstance(model, AIModelEntity):
|
|
25
|
+
raise TypeError("model must be an instance of AIModelEntity")
|
|
23
26
|
self.model = None
|
|
24
27
|
self.metadata = None
|
|
25
28
|
|
|
26
29
|
if model:
|
|
27
30
|
self.load_model(model)
|
|
28
31
|
|
|
29
|
-
def load_model(self, model):
|
|
32
|
+
def load_model(self, model: AIModelEntity):
|
|
33
|
+
if not isinstance(model, AIModelEntity):
|
|
34
|
+
raise TypeError("model must be an instance of AIModelEntity")
|
|
30
35
|
self.metadata = model
|
|
31
36
|
path = DatabaseManager.STORAGE_PATHS["models"] / model.file
|
|
32
37
|
|
|
@@ -44,17 +49,30 @@ class RFDETRDetector(BaseDetector):
|
|
|
44
49
|
self.model = None
|
|
45
50
|
return False
|
|
46
51
|
|
|
47
|
-
def detect_objects(self, frame, confidence_threshold=0.7):
|
|
52
|
+
def detect_objects(self, frame, confidence_threshold=0.7, class_thresholds=None):
|
|
48
53
|
if self.model is None:
|
|
49
54
|
return []
|
|
50
55
|
|
|
51
56
|
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
52
57
|
results = self.model.predict(frame_rgb, confidence_threshold)
|
|
53
58
|
|
|
59
|
+
class_names = self.metadata.get_classes() if hasattr(self.metadata, "get_classes") else None
|
|
60
|
+
if not class_names:
|
|
61
|
+
class_names = getattr(self.model, "class_names", None)
|
|
62
|
+
|
|
54
63
|
detections = []
|
|
55
|
-
for class_id, conf, xyxy in zip(results.class_id, results.confidence, results.xyxy):
|
|
64
|
+
for class_id, conf, xyxy in zip(results.class_id, results.confidence, results.xyxy):
|
|
65
|
+
label = class_names[class_id - 1] if class_names else str(class_id)
|
|
66
|
+
|
|
67
|
+
threshold = confidence_threshold
|
|
68
|
+
if class_thresholds and label in class_thresholds:
|
|
69
|
+
threshold = class_thresholds[label]
|
|
70
|
+
|
|
71
|
+
if conf < threshold:
|
|
72
|
+
continue
|
|
73
|
+
|
|
56
74
|
detections.append({
|
|
57
|
-
"label":
|
|
75
|
+
"label": label,
|
|
58
76
|
"confidence": conf,
|
|
59
77
|
"bbox": xyxy
|
|
60
78
|
})
|
|
@@ -3,18 +3,23 @@ import logging
|
|
|
3
3
|
from ultralytics import YOLO
|
|
4
4
|
from ..database.DatabaseManager import DatabaseManager
|
|
5
5
|
from .BaseDetector import BaseDetector
|
|
6
|
+
from ..models.ai_model import AIModelEntity
|
|
6
7
|
|
|
7
8
|
logging.getLogger("ultralytics").setLevel(logging.WARNING)
|
|
8
9
|
|
|
9
10
|
class YOLODetector(BaseDetector):
|
|
10
|
-
def __init__(self, model):
|
|
11
|
+
def __init__(self, model: AIModelEntity):
|
|
12
|
+
if not isinstance(model, AIModelEntity):
|
|
13
|
+
raise TypeError("model must be an instance of AIModelEntity")
|
|
11
14
|
self.model = None
|
|
12
15
|
self.metadata = None
|
|
13
16
|
|
|
14
17
|
if model:
|
|
15
18
|
self.load_model(model)
|
|
16
19
|
|
|
17
|
-
def load_model(self, model):
|
|
20
|
+
def load_model(self, model: AIModelEntity):
|
|
21
|
+
if not isinstance(model, AIModelEntity):
|
|
22
|
+
raise TypeError("model must be an instance of AIModelEntity")
|
|
18
23
|
self.metadata = model
|
|
19
24
|
path = DatabaseManager.STORAGE_PATHS["models"] / model.file
|
|
20
25
|
|
|
@@ -32,20 +37,28 @@ class YOLODetector(BaseDetector):
|
|
|
32
37
|
self.model = None
|
|
33
38
|
return False
|
|
34
39
|
|
|
35
|
-
def detect_objects(self, frame,
|
|
40
|
+
def detect_objects(self, frame, confidence_threshold=0.7, class_thresholds=None):
|
|
36
41
|
if self.model is None:
|
|
37
42
|
return []
|
|
38
43
|
|
|
39
44
|
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
40
45
|
results = self.model(frame_rgb)
|
|
41
46
|
|
|
47
|
+
class_names = self.metadata.get_classes()
|
|
48
|
+
if not class_names:
|
|
49
|
+
class_names = self.model.names
|
|
50
|
+
|
|
42
51
|
detections = []
|
|
43
52
|
for box in results[0].boxes:
|
|
44
53
|
class_id = int(box.cls)
|
|
45
|
-
label =
|
|
54
|
+
label = class_names[class_id]
|
|
46
55
|
confidence = float(box.conf)
|
|
47
56
|
|
|
48
|
-
|
|
57
|
+
threshold = confidence_threshold
|
|
58
|
+
if class_thresholds and label in class_thresholds:
|
|
59
|
+
threshold = class_thresholds[label]
|
|
60
|
+
|
|
61
|
+
if confidence < threshold:
|
|
49
62
|
continue
|
|
50
63
|
|
|
51
64
|
detections.append({
|
|
@@ -6,6 +6,7 @@ from .DetectionProcessor import DetectionProcessor
|
|
|
6
6
|
from ...pipeline.PipelineConfigManager import PipelineConfigManager
|
|
7
7
|
from ...repositories.RestrictedAreaRepository import RestrictedAreaRepository
|
|
8
8
|
from ...util.PersonRestrictedAreaMatcher import PersonRestrictedAreaMatcher
|
|
9
|
+
from ...callbacks import DetectionType, DetectionAttribute, BoundingBox, DetectionData
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
class HumanDetectionProcessor(DetectionProcessor):
|
|
@@ -16,16 +17,33 @@ class HumanDetectionProcessor(DetectionProcessor):
|
|
|
16
17
|
def __init__(self):
|
|
17
18
|
self.repository = RestrictedAreaRepository()
|
|
18
19
|
self.restricted_areas = []
|
|
20
|
+
self.main_class_threshold = 0.7
|
|
21
|
+
self.main_class = "person" # Default fallback
|
|
19
22
|
|
|
20
|
-
def update(self, config_manager: PipelineConfigManager):
|
|
21
|
-
config = config_manager.get_feature_config(self.code,
|
|
23
|
+
def update(self, config_manager: PipelineConfigManager, ai_model=None):
|
|
24
|
+
config = config_manager.get_feature_config(self.code, {})
|
|
22
25
|
area_list = config.get("restrictedArea", [])
|
|
23
26
|
self.restricted_areas = [
|
|
24
27
|
[(p["x"], p["y"]) for p in area] for area in area_list
|
|
25
28
|
]
|
|
29
|
+
|
|
30
|
+
# Update main class threshold
|
|
31
|
+
self.main_class_threshold = config.get("minimumDetectionConfidence", 0.7)
|
|
32
|
+
|
|
33
|
+
# Update main class from AI model
|
|
34
|
+
if ai_model and ai_model.get_main_class():
|
|
35
|
+
self.main_class = ai_model.get_main_class()
|
|
36
|
+
else:
|
|
37
|
+
self.main_class = "person" # Default fallback
|
|
38
|
+
|
|
39
|
+
def get_main_class_threshold(self, ai_model=None):
|
|
40
|
+
"""Get the confidence threshold for the main class (person)"""
|
|
41
|
+
if ai_model and ai_model.get_main_class():
|
|
42
|
+
return self.main_class_threshold
|
|
43
|
+
return None
|
|
26
44
|
|
|
27
45
|
def process(self, detections: List[Dict[str, Any]], dimension: Tuple[int, int]) -> List[Dict[str, Any]]:
|
|
28
|
-
persons = [d for d in detections if d["label"] ==
|
|
46
|
+
persons = [d for d in detections if d["label"] == self.main_class]
|
|
29
47
|
|
|
30
48
|
height, width = dimension
|
|
31
49
|
area_polygons = []
|
|
@@ -45,3 +63,39 @@ class HumanDetectionProcessor(DetectionProcessor):
|
|
|
45
63
|
self.repository.save_area_violation(
|
|
46
64
|
pipeline_id, worker_source_id, frame_counter, tracked_objects, frame, frame_drawer
|
|
47
65
|
)
|
|
66
|
+
|
|
67
|
+
def get_multi_instance_classes(self):
|
|
68
|
+
"""Human detection doesn't have multi-instance classes"""
|
|
69
|
+
return []
|
|
70
|
+
|
|
71
|
+
@staticmethod
|
|
72
|
+
def create_detection_data(pipeline_id: str, worker_source_id: str, person_id: str,
|
|
73
|
+
detection_id: str, tracked_obj: Dict[str, Any],
|
|
74
|
+
image_path: str = "", image_tile_path: str = "",
|
|
75
|
+
frame_id: int = 0) -> DetectionData:
|
|
76
|
+
"""Create DetectionData from area violation data."""
|
|
77
|
+
bbox = BoundingBox.from_list(tracked_obj["bbox"])
|
|
78
|
+
|
|
79
|
+
attributes = []
|
|
80
|
+
for attr in tracked_obj.get("attributes", []):
|
|
81
|
+
# Area violations are always violations
|
|
82
|
+
attributes.append(DetectionAttribute(
|
|
83
|
+
label=attr["label"],
|
|
84
|
+
confidence=attr.get("confidence", 1.0),
|
|
85
|
+
count=attr.get("count", 0),
|
|
86
|
+
is_violation=True
|
|
87
|
+
))
|
|
88
|
+
|
|
89
|
+
return DetectionData(
|
|
90
|
+
detection_type=DetectionType.AREA_VIOLATION,
|
|
91
|
+
detection_id=detection_id,
|
|
92
|
+
person_id=person_id,
|
|
93
|
+
pipeline_id=pipeline_id,
|
|
94
|
+
worker_source_id=worker_source_id,
|
|
95
|
+
confidence_score=tracked_obj.get("confidence", 1.0),
|
|
96
|
+
bbox=bbox,
|
|
97
|
+
attributes=attributes,
|
|
98
|
+
image_path=image_path,
|
|
99
|
+
image_tile_path=image_tile_path,
|
|
100
|
+
frame_id=frame_id
|
|
101
|
+
)
|