matrice 1.0.99146__py3-none-any.whl → 1.0.99148__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- matrice/deploy/utils/post_processing/__init__.py +6 -0
- matrice/deploy/utils/post_processing/config.py +2 -0
- matrice/deploy/utils/post_processing/core/config.py +30 -0
- matrice/deploy/utils/post_processing/processor.py +4 -2
- matrice/deploy/utils/post_processing/usecases/__init__.py +3 -0
- matrice/deploy/utils/post_processing/usecases/fire_detection.py +472 -473
- matrice/deploy/utils/post_processing/usecases/smoker_detection.py +833 -0
- {matrice-1.0.99146.dist-info → matrice-1.0.99148.dist-info}/METADATA +1 -1
- {matrice-1.0.99146.dist-info → matrice-1.0.99148.dist-info}/RECORD +12 -11
- {matrice-1.0.99146.dist-info → matrice-1.0.99148.dist-info}/WHEEL +0 -0
- {matrice-1.0.99146.dist-info → matrice-1.0.99148.dist-info}/licenses/LICENSE.txt +0 -0
- {matrice-1.0.99146.dist-info → matrice-1.0.99148.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,833 @@
|
|
1
|
+
from typing import Any, Dict, List, Optional
|
2
|
+
from dataclasses import asdict
|
3
|
+
import time
|
4
|
+
from datetime import datetime, timezone
|
5
|
+
|
6
|
+
from ..core.base import BaseProcessor, ProcessingContext, ProcessingResult, ConfigProtocol, ResultFormat
|
7
|
+
from ..utils import (
|
8
|
+
filter_by_confidence,
|
9
|
+
filter_by_categories,
|
10
|
+
apply_category_mapping,
|
11
|
+
count_objects_by_category,
|
12
|
+
count_objects_in_zones,
|
13
|
+
calculate_counting_summary,
|
14
|
+
match_results_structure,
|
15
|
+
bbox_smoothing,
|
16
|
+
BBoxSmoothingConfig,
|
17
|
+
BBoxSmoothingTracker
|
18
|
+
)
|
19
|
+
from dataclasses import dataclass, field
|
20
|
+
from ..core.config import BaseConfig, AlertConfig, ZoneConfig
|
21
|
+
|
22
|
+
|
23
|
+
@dataclass
class SmokerDetectionConfig(BaseConfig):
    """Settings for the smoker-detection post-processing use case.

    Groups the bounding-box smoothing knobs, the confidence cut-off, the
    category vocabulary of the detector and the optional alert configuration.
    Field order is part of the positional constructor signature and must not
    change.
    """

    # ---- bounding-box smoothing ------------------------------------------
    enable_smoothing: bool = True
    smoothing_algorithm: str = "observability"  # "window" or "observability"
    smoothing_window_size: int = 20
    smoothing_cooldown_frames: int = 5
    smoothing_confidence_range_factor: float = 0.5

    # ---- confidence threshold --------------------------------------------
    confidence_threshold: float = 0.6

    # ---- categories --------------------------------------------------------
    # Every category the underlying detector can emit.
    usecase_categories: List[str] = field(
        default_factory=lambda: list(("Cigarette", "Person", "Smoke", "Vape"))
    )

    # Categories of interest for this use case.
    target_categories: List[str] = field(
        default_factory=lambda: list(("Cigarette", "Person", "Smoke", "Vape"))
    )

    # ---- alerting ----------------------------------------------------------
    # When left as None the processor produces no alerts.
    alert_config: Optional[AlertConfig] = None

    # Detector class index -> category name (0: Cigarette, 1: Person, 2: Smoke, 3: Vape).
    index_to_category: Optional[Dict[int, str]] = field(
        default_factory=lambda: dict(enumerate(("Cigarette", "Person", "Smoke", "Vape")))
    )
|
54
|
+
|
55
|
+
|
56
|
+
class SmokerDetectionUseCase(BaseProcessor):
|
57
|
+
# Display label for each internal category name. Every category is currently
# shown under its own name, so this is an identity mapping (insertion order:
# Cigarette, Person, Smoke, Vape).
CATEGORY_DISPLAY = {label: label for label in ("Cigarette", "Person", "Smoke", "Vape")}
|
64
|
+
|
65
|
+
|
66
|
+
def __init__(self):
    """Initialise per-instance state for the smoker-detection processor.

    The smoothing tracker and the advanced tracker are not built here; both
    start as ``None`` and are created by ``process`` the first time they are
    needed, so their state then persists across frames.
    """
    super().__init__("smoker_detection")
    self.category = "general"

    # Identity reported in summaries, incidents and alerts.
    self.CASE_TYPE: Optional[str] = 'smoker_detection'
    self.CASE_VERSION: Optional[str] = '1.2'

    # Categories whose unique track IDs are counted.
    self.target_categories = ['Cigarette', 'Person', 'Smoke', 'Vape']

    # Lazily-created helpers.
    self.smoothing_tracker = None
    self.tracker = None

    # Frame bookkeeping.
    self._total_frame_counter = 0
    self._global_frame_offset = 0

    # Epoch seconds anchoring the "TOTAL SINCE" line; filled in on first use.
    self._tracking_start_time = None

    # Track-ID merging state. NOTE(review): presumably consumed by
    # _merge_or_register_track, which is not visible here.
    self._track_aliases: Dict[Any, Any] = {}
    self._canonical_tracks: Dict[Any, Dict[str, Any]] = {}
    # Tunable merge parameters – adjust if necessary for specific scenarios.
    self._track_merge_iou_threshold: float = 0.05  # IoU >= 0.05 between tracks allows a merge
    self._track_merge_time_window: float = 7.0  # seconds within which to merge

    # Rolling history of per-frame severity codes (0=low .. 3=critical).
    self._ascending_alert_list: List[int] = []
    # Incident lifecycle marker: "N/A", "Incident still active", or an end timestamp.
    self.current_incident_end_timestamp: str = "N/A"
|
96
|
+
|
97
|
+
|
98
|
+
def process(self, data: Any, config: ConfigProtocol, context: Optional[ProcessingContext] = None,
            stream_info: Optional[Dict[str, Any]] = None) -> ProcessingResult:
    """Main entry point for post-processing one batch of detections.

    Pipeline: confidence filtering -> category-index mapping -> category
    filtering -> optional bounding-box smoothing -> advanced tracking ->
    unique-track bookkeeping -> counting, alerting, incidents, tracking
    statistics and a human-readable summary.

    Args:
        data: Raw detections for the current frame(s).
        config: Must be a ``SmokerDetectionConfig``; any other type yields an
            error result.
        context: Processing context to annotate; created if omitted.
        stream_info: Optional stream/video metadata. When
            ``input_settings.start_frame == input_settings.end_frame`` that
            value is used as the frame number.

    Returns:
        A ``ProcessingResult`` whose data is
        ``{"agg_summary": {str(frame_number): {...}}}``.
    """
    # Ensure config is the correct type.
    if not isinstance(config, SmokerDetectionConfig):
        return self.create_error_result("Invalid config type", usecase=self.name, category=self.category,
                                        context=context)
    if context is None:
        context = ProcessingContext()

    # Detect input format and record it, along with the threshold, on the context.
    input_format = match_results_structure(data)
    context.input_format = input_format
    context.confidence_threshold = config.confidence_threshold

    # Step 1: confidence filtering.
    if config.confidence_threshold is not None:
        processed_data = filter_by_confidence(data, config.confidence_threshold)
        self.logger.debug(f"Applied confidence filtering with threshold {config.confidence_threshold}")
    else:
        processed_data = data
        self.logger.debug("Did not apply confidence filtering with threshold since nothing was provided")

    # Step 2: apply category mapping if provided.
    if config.index_to_category:
        processed_data = apply_category_mapping(processed_data, config.index_to_category)
        self.logger.debug("Applied category mapping")

    # Step 3: keep only the configured target categories. The config value is
    # what is checked, so it is also what is filtered on (previously the
    # hard-coded self.target_categories was used, ignoring the config).
    if config.target_categories:
        wanted_categories = set(config.target_categories)
        processed_data = [d for d in processed_data if d.get('category') in wanted_categories]
        self.logger.debug("Applied category filtering")

    # Step 4: optional bbox smoothing. The smoothing tracker is built once and
    # reused so that its state carries across frames.
    if config.enable_smoothing:
        if self.smoothing_tracker is None:
            smoothing_config = BBoxSmoothingConfig(
                smoothing_algorithm=config.smoothing_algorithm,
                window_size=config.smoothing_window_size,
                cooldown_frames=config.smoothing_cooldown_frames,
                confidence_threshold=config.confidence_threshold,  # reuse the detection threshold
                confidence_range_factor=config.smoothing_confidence_range_factor,
                enable_smoothing=True,
            )
            self.smoothing_tracker = BBoxSmoothingTracker(smoothing_config)
        processed_data = bbox_smoothing(processed_data, self.smoothing_tracker.config, self.smoothing_tracker)

    # Step 5: advanced (BYTETracker-like) tracking; the tracker instance is
    # created once and preserves its state across frames.
    try:
        from ..advanced_tracker import AdvancedTracker
        from ..advanced_tracker.config import TrackerConfig

        if self.tracker is None:
            # Derive tracker thresholds from the use-case confidence threshold so
            # that lower-confidence detections can still start tracks when the
            # user configures a lower `confidence_threshold`.
            if config.confidence_threshold is not None:
                threshold = float(config.confidence_threshold)
                tracker_config = TrackerConfig(
                    track_high_thresh=threshold,
                    # Allow even lower detections to participate in secondary association.
                    track_low_thresh=max(0.05, threshold / 2),
                    new_track_thresh=threshold,
                )
            else:
                tracker_config = TrackerConfig()
            self.tracker = AdvancedTracker(tracker_config)
            self.logger.info(
                "Initialized AdvancedTracker for Monitoring and tracking with thresholds: "
                f"high={tracker_config.track_high_thresh}, "
                f"low={tracker_config.track_low_thresh}, "
                f"new={tracker_config.new_track_thresh}"
            )

        # The tracker consumes the same format as its input and adds
        # track_id / frame_id to each detection.
        processed_data = self.tracker.update(processed_data)
    except Exception as e:
        # Best effort: if tracking fails, continue with the filtered (and,
        # if enabled, smoothed) detections without track IDs.
        self.logger.warning(f"AdvancedTracker failed: {e}")

    # Record unique (canonical) track IDs per category for total counts.
    self._update_tracking_state(processed_data)
    self._total_frame_counter += 1

    # A single-frame request (start_frame == end_frame) supplies the frame number.
    frame_number = None
    if stream_info:
        input_settings = stream_info.get("input_settings", {})
        start_frame = input_settings.get("start_frame")
        end_frame = input_settings.get("end_frame")
        if start_frame is not None and end_frame is not None and start_frame == end_frame:
            frame_number = start_frame

    # Counting, with cumulative unique counts kept in local state.
    counting_summary = self._count_categories(processed_data, config)
    counting_summary['total_counts'] = self.get_total_counts()

    alerts = self._check_alerts(counting_summary, frame_number, config)
    predictions = self._extract_predictions(processed_data)

    # Structured incidents, tracking stats and business analytics.
    incidents_list = self._generate_incidents(counting_summary, alerts, config, frame_number, stream_info)
    tracking_stats_list = self._generate_tracking_stats(counting_summary, alerts, config, frame_number, stream_info)
    business_analytics_list = self._generate_business_analytics(counting_summary, alerts, config, stream_info, is_empty=True)
    summary_list = self._generate_summary(counting_summary, incidents_list, tracking_stats_list, business_analytics_list, alerts)

    # Each generator returns a (possibly empty) list; take the first entry.
    incidents = incidents_list[0] if incidents_list else {}
    tracking_stats = tracking_stats_list[0] if tracking_stats_list else {}
    business_analytics = business_analytics_list[0] if business_analytics_list else {}
    summary = summary_list[0] if summary_list else {}
    agg_summary = {str(frame_number): {
        "incidents": incidents,
        "tracking_stats": tracking_stats,
        "business_analytics": business_analytics,
        "alerts": alerts,
        "human_text": summary}
    }

    context.mark_completed()

    return self.create_result(
        data={"agg_summary": agg_summary},
        usecase=self.name,
        category=self.category,
        context=context
    )
|
242
|
+
|
243
|
+
def _check_alerts(self, summary: dict, frame_number:Any, config: SmokerDetectionConfig) -> List[Dict]:
|
244
|
+
"""
|
245
|
+
Check if any alert thresholds are exceeded and return alert dicts.
|
246
|
+
"""
|
247
|
+
def get_trend(data, lookback=900, threshold=0.6):
|
248
|
+
'''
|
249
|
+
Determine if the trend is ascending or descending based on actual value progression.
|
250
|
+
Now works with values 0,1,2,3 (not just binary).
|
251
|
+
'''
|
252
|
+
window = data[-lookback:] if len(data) >= lookback else data
|
253
|
+
if len(window) < 2:
|
254
|
+
return True # not enough data to determine trend
|
255
|
+
increasing = 0
|
256
|
+
total = 0
|
257
|
+
for i in range(1, len(window)):
|
258
|
+
if window[i] >= window[i - 1]:
|
259
|
+
increasing += 1
|
260
|
+
total += 1
|
261
|
+
ratio = increasing / total
|
262
|
+
if ratio >= threshold:
|
263
|
+
return True
|
264
|
+
elif ratio <= (1 - threshold):
|
265
|
+
return False
|
266
|
+
|
267
|
+
frame_key = str(frame_number) if frame_number is not None else "current_frame"
|
268
|
+
alerts = []
|
269
|
+
total_detections = summary.get("total_count", 0) #CURRENT combined total count of all classes
|
270
|
+
total_counts_dict = summary.get("total_counts", {}) #TOTAL cumulative counts per class
|
271
|
+
cumulative_total = sum(total_counts_dict.values()) if total_counts_dict else 0 #TOTAL combined cumulative count
|
272
|
+
per_category_count = summary.get("per_category_count", {}) #CURRENT count per class
|
273
|
+
|
274
|
+
if not config.alert_config:
|
275
|
+
return alerts
|
276
|
+
|
277
|
+
total = summary.get("total_count", 0)
|
278
|
+
#self._ascending_alert_list
|
279
|
+
if hasattr(config.alert_config, 'count_thresholds') and config.alert_config.count_thresholds:
|
280
|
+
|
281
|
+
for category, threshold in config.alert_config.count_thresholds.items():
|
282
|
+
if category == "all" and total > threshold:
|
283
|
+
|
284
|
+
alerts.append({
|
285
|
+
"alert_type": getattr(config.alert_config, 'alert_type', ['Default']) if hasattr(config.alert_config, 'alert_type') else ['Default'],
|
286
|
+
"alert_id": "alert_"+category+'_'+frame_key,
|
287
|
+
"incident_category": self.CASE_TYPE,
|
288
|
+
"threshold_level": threshold,
|
289
|
+
"ascending": get_trend(self._ascending_alert_list, lookback=900, threshold=0.8),
|
290
|
+
"settings": {t: v for t, v in zip(getattr(config.alert_config, 'alert_type', ['Default']) if hasattr(config.alert_config, 'alert_type') else ['Default'],
|
291
|
+
getattr(config.alert_config, 'alert_value', ['JSON']) if hasattr(config.alert_config, 'alert_value') else ['JSON'])
|
292
|
+
}
|
293
|
+
})
|
294
|
+
elif category in summary.get("per_category_count", {}):
|
295
|
+
count = summary.get("per_category_count", {})[category]
|
296
|
+
if count > threshold: # Fixed logic: alert when EXCEEDING threshold
|
297
|
+
alerts.append({
|
298
|
+
"alert_type": getattr(config.alert_config, 'alert_type', ['Default']) if hasattr(config.alert_config, 'alert_type') else ['Default'],
|
299
|
+
"alert_id": "alert_"+category+'_'+frame_key,
|
300
|
+
"incident_category": self.CASE_TYPE,
|
301
|
+
"threshold_level": threshold,
|
302
|
+
"ascending": get_trend(self._ascending_alert_list, lookback=900, threshold=0.8),
|
303
|
+
"settings": {t: v for t, v in zip(getattr(config.alert_config, 'alert_type', ['Default']) if hasattr(config.alert_config, 'alert_type') else ['Default'],
|
304
|
+
getattr(config.alert_config, 'alert_value', ['JSON']) if hasattr(config.alert_config, 'alert_value') else ['JSON'])
|
305
|
+
}
|
306
|
+
})
|
307
|
+
else:
|
308
|
+
pass
|
309
|
+
return alerts
|
310
|
+
|
311
|
+
def _generate_incidents(self, counting_summary: Dict, alerts: List, config: SmokerDetectionConfig,
|
312
|
+
frame_number: Optional[int] = None, stream_info: Optional[Dict[str, Any]] = None) -> List[
|
313
|
+
Dict]:
|
314
|
+
"""Generate structured incidents for the output format with frame-based keys."""
|
315
|
+
|
316
|
+
incidents = []
|
317
|
+
total_detections = counting_summary.get("total_count", 0)
|
318
|
+
current_timestamp = self._get_current_timestamp_str(stream_info)
|
319
|
+
camera_info = self.get_camera_info_from_stream(stream_info)
|
320
|
+
|
321
|
+
self._ascending_alert_list = self._ascending_alert_list[-900:] if len(self._ascending_alert_list) > 900 else self._ascending_alert_list
|
322
|
+
|
323
|
+
if total_detections > 0:
|
324
|
+
# Determine event level based on thresholds
|
325
|
+
level = "low"
|
326
|
+
intensity = 5.0
|
327
|
+
start_timestamp = self._get_start_timestamp_str(stream_info)
|
328
|
+
if start_timestamp and self.current_incident_end_timestamp=='N/A':
|
329
|
+
self.current_incident_end_timestamp = 'Incident still active'
|
330
|
+
elif start_timestamp and self.current_incident_end_timestamp=='Incident still active':
|
331
|
+
if len(self._ascending_alert_list) >= 15 and sum(self._ascending_alert_list[-15:]) / 15 < 1.5:
|
332
|
+
self.current_incident_end_timestamp = current_timestamp
|
333
|
+
elif self.current_incident_end_timestamp!='Incident still active' and self.current_incident_end_timestamp!='N/A':
|
334
|
+
self.current_incident_end_timestamp = 'N/A'
|
335
|
+
|
336
|
+
if config.alert_config and config.alert_config.count_thresholds:
|
337
|
+
threshold = config.alert_config.count_thresholds.get("all", 15)
|
338
|
+
intensity = min(10.0, (total_detections / threshold) * 10)
|
339
|
+
|
340
|
+
if intensity >= 9:
|
341
|
+
level = "critical"
|
342
|
+
self._ascending_alert_list.append(3)
|
343
|
+
elif intensity >= 7:
|
344
|
+
level = "significant"
|
345
|
+
self._ascending_alert_list.append(2)
|
346
|
+
elif intensity >= 5:
|
347
|
+
level = "medium"
|
348
|
+
self._ascending_alert_list.append(1)
|
349
|
+
else:
|
350
|
+
level = "low"
|
351
|
+
self._ascending_alert_list.append(0)
|
352
|
+
else:
|
353
|
+
if total_detections > 30:
|
354
|
+
level = "critical"
|
355
|
+
intensity = 10.0
|
356
|
+
self._ascending_alert_list.append(3)
|
357
|
+
elif total_detections > 25:
|
358
|
+
level = "significant"
|
359
|
+
intensity = 9.0
|
360
|
+
self._ascending_alert_list.append(2)
|
361
|
+
elif total_detections > 15:
|
362
|
+
level = "medium"
|
363
|
+
intensity = 7.0
|
364
|
+
self._ascending_alert_list.append(1)
|
365
|
+
else:
|
366
|
+
level = "low"
|
367
|
+
intensity = min(10.0, total_detections / 3.0)
|
368
|
+
self._ascending_alert_list.append(0)
|
369
|
+
|
370
|
+
# Generate human text in new format
|
371
|
+
human_text_lines = [f"INCIDENTS DETECTED @ {current_timestamp}:"]
|
372
|
+
human_text_lines.append(f"\tSeverity Level: {(self.CASE_TYPE,level)}")
|
373
|
+
human_text = "\n".join(human_text_lines)
|
374
|
+
|
375
|
+
alert_settings=[]
|
376
|
+
if config.alert_config and hasattr(config.alert_config, 'alert_type'):
|
377
|
+
alert_settings.append({
|
378
|
+
"alert_type": getattr(config.alert_config, 'alert_type', ['Default']) if hasattr(config.alert_config, 'alert_type') else ['Default'],
|
379
|
+
"incident_category": self.CASE_TYPE,
|
380
|
+
"threshold_level": config.alert_config.count_thresholds if hasattr(config.alert_config, 'count_thresholds') else {},
|
381
|
+
"ascending": True,
|
382
|
+
"settings": {t: v for t, v in zip(getattr(config.alert_config, 'alert_type', ['Default']) if hasattr(config.alert_config, 'alert_type') else ['Default'],
|
383
|
+
getattr(config.alert_config, 'alert_value', ['JSON']) if hasattr(config.alert_config, 'alert_value') else ['JSON'])
|
384
|
+
}
|
385
|
+
})
|
386
|
+
|
387
|
+
event= self.create_incident(incident_id=self.CASE_TYPE+'_'+str(frame_number), incident_type=self.CASE_TYPE,
|
388
|
+
severity_level=level, human_text=human_text, camera_info=camera_info, alerts=alerts, alert_settings=alert_settings,
|
389
|
+
start_time=start_timestamp, end_time=self.current_incident_end_timestamp,
|
390
|
+
level_settings= {"low": 1, "medium": 3, "significant":4, "critical": 7})
|
391
|
+
incidents.append(event)
|
392
|
+
|
393
|
+
else:
|
394
|
+
self._ascending_alert_list.append(0)
|
395
|
+
incidents.append({})
|
396
|
+
|
397
|
+
return incidents
|
398
|
+
def _generate_tracking_stats(
        self,
        counting_summary: Dict,
        alerts: List,
        config: SmokerDetectionConfig,
        frame_number: Optional[int] = None,
        stream_info: Optional[Dict[str, Any]] = None
) -> List[Dict]:
    """Build the tracking-statistics record for the current frame.

    Combines the cumulative per-category counts (``"total_counts"``), the
    current per-category counts (``"per_category_count"``), the detections
    (with segmentation when present), alert settings, a daily 09:00 reset
    policy and a human-readable text into one ``create_tracking_stats`` call.

    Returns:
        A one-element list containing that record.
    """
    camera_info = self.get_camera_info_from_stream(stream_info)

    live_total = counting_summary.get("total_count", 0)  # current combined count
    cumulative_by_cat = counting_summary.get("total_counts", {})  # cumulative unique counts
    live_by_cat = counting_summary.get("per_category_count", {})  # current counts

    # Call order preserved: the start-timestamp helper may initialise state.
    now_str = self._get_current_timestamp_str(stream_info, precision=False)
    since_str = self._get_start_timestamp_str(stream_info, precision=False)
    precise_now = self._get_current_timestamp_str(stream_info, precision=True)
    precise_reset = self._get_start_timestamp_str(stream_info, precision=True)

    total_counts = [
        {"category": cat, "count": n}
        for cat, n in cumulative_by_cat.items()
        if n > 0
    ]
    # Zero-count categories are kept whenever anything was detected this frame.
    current_counts = [
        {"category": cat, "count": n}
        for cat, n in live_by_cat.items()
        if n > 0 or live_total > 0
    ]

    # Detections without confidence scores; first non-empty of masks /
    # segmentation / mask is attached as the segmentation.
    detections = []
    for det in counting_summary.get("detections", []):
        box = det.get("bounding_box", {})
        label = det.get("category", "person")
        seg = det.get("masks") or det.get("segmentation") or det.get("mask")
        if seg:
            detections.append(self.create_detection_object(label, box, segmentation=seg))
        else:
            detections.append(self.create_detection_object(label, box))

    alert_settings = []
    if config.alert_config and hasattr(config.alert_config, 'alert_type'):
        kinds = getattr(config.alert_config, 'alert_type', ['Default'])
        values = getattr(config.alert_config, 'alert_value', ['JSON'])
        alert_settings.append({
            "alert_type": kinds,
            "incident_category": self.CASE_TYPE,
            "threshold_level": getattr(config.alert_config, 'count_thresholds', {}),
            "ascending": True,
            "settings": dict(zip(kinds, values)),
        })

    text_lines = ["Tracking Statistics:", f"CURRENT FRAME @ {now_str}"]
    text_lines.extend(f"\t{cat}: {n}" for cat, n in live_by_cat.items())
    text_lines.append(f"TOTAL SINCE {since_str}")
    text_lines.extend(f"\t{cat}: {n}" for cat, n in cumulative_by_cat.items() if n > 0)
    if alerts:
        text_lines.extend(f"Alerts: {a.get('settings', {})} sent @ {now_str}" for a in alerts)
    else:
        text_lines.append("Alerts: None")
    human_text = "\n".join(text_lines)

    reset_settings = [{"interval_type": "daily", "reset_time": {"value": 9, "time_unit": "hour"}}]

    stat = self.create_tracking_stats(total_counts=total_counts, current_counts=current_counts,
                                      detections=detections, human_text=human_text, camera_info=camera_info,
                                      alerts=alerts, alert_settings=alert_settings,
                                      reset_settings=reset_settings, start_time=precise_now,
                                      reset_time=precise_reset)
    return [stat]
|
513
|
+
|
514
|
+
def _generate_business_analytics(self, counting_summary: Dict, alerts:Any, config: SmokerDetectionConfig, stream_info: Optional[Dict[str, Any]] = None, is_empty=False) -> List[Dict]:
|
515
|
+
"""Generate standardized business analytics for the agg_summary structure."""
|
516
|
+
if is_empty:
|
517
|
+
return []
|
518
|
+
|
519
|
+
#-----IF YOUR USECASE NEEDS BUSINESS ANALYTICS, YOU CAN USE THIS FUNCTION------#
|
520
|
+
#camera_info = self.get_camera_info_from_stream(stream_info)
|
521
|
+
# business_analytics = self.create_business_analytics(nalysis_name, statistics,
|
522
|
+
# human_text, camera_info=camera_info, alerts=alerts, alert_settings=alert_settings,
|
523
|
+
# reset_settings)
|
524
|
+
# return business_analytics
|
525
|
+
|
526
|
+
def _generate_summary(self, summary: dict, incidents: List, tracking_stats: List, business_analytics: List, alerts: List) -> List[str]:
|
527
|
+
"""
|
528
|
+
Generate a human_text string for the tracking_stat, incident, business analytics and alerts.
|
529
|
+
"""
|
530
|
+
lines = {}
|
531
|
+
lines["Application Name"] = self.CASE_TYPE
|
532
|
+
lines["Application Version"] = self.CASE_VERSION
|
533
|
+
if len(incidents) > 0:
|
534
|
+
lines["Incidents:"]=f"\n\t{incidents[0].get('human_text', 'No incidents detected')}\n"
|
535
|
+
if len(tracking_stats) > 0:
|
536
|
+
lines["Tracking Statistics:"]=f"\t{tracking_stats[0].get('human_text', 'No tracking statistics detected')}\n"
|
537
|
+
if len(business_analytics) > 0:
|
538
|
+
lines["Business Analytics:"]=f"\t{business_analytics[0].get('human_text', 'No business analytics detected')}\n"
|
539
|
+
|
540
|
+
if len(incidents) == 0 and len(tracking_stats) == 0 and len(business_analytics) == 0:
|
541
|
+
lines["Summary"] = "No Summary Data"
|
542
|
+
|
543
|
+
return [lines]
|
544
|
+
|
545
|
+
def _get_track_ids_info(self, detections: list) -> Dict[str, Any]:
|
546
|
+
"""
|
547
|
+
Get detailed information about track IDs (per frame).
|
548
|
+
"""
|
549
|
+
# Collect all track_ids in this frame
|
550
|
+
frame_track_ids = set()
|
551
|
+
for det in detections:
|
552
|
+
tid = det.get('track_id')
|
553
|
+
if tid is not None:
|
554
|
+
frame_track_ids.add(tid)
|
555
|
+
# Use persistent total set for unique counting
|
556
|
+
total_track_ids = set()
|
557
|
+
for s in getattr(self, '_per_category_total_track_ids', {}).values():
|
558
|
+
total_track_ids.update(s)
|
559
|
+
return {
|
560
|
+
"total_count": len(total_track_ids),
|
561
|
+
"current_frame_count": len(frame_track_ids),
|
562
|
+
"total_unique_track_ids": len(total_track_ids),
|
563
|
+
"current_frame_track_ids": list(frame_track_ids),
|
564
|
+
"last_update_time": time.time(),
|
565
|
+
"total_frames_processed": getattr(self, '_total_frame_counter', 0)
|
566
|
+
}
|
567
|
+
|
568
|
+
def _update_tracking_state(self, detections: list):
|
569
|
+
"""
|
570
|
+
Track unique categories track_ids per category for total count after tracking.
|
571
|
+
Applies canonical ID merging to avoid duplicate counting when the underlying
|
572
|
+
tracker loses an object temporarily and assigns a new ID.
|
573
|
+
"""
|
574
|
+
# Lazily initialise storage dicts
|
575
|
+
if not hasattr(self, "_per_category_total_track_ids"):
|
576
|
+
self._per_category_total_track_ids = {cat: set() for cat in self.target_categories}
|
577
|
+
self._current_frame_track_ids = {cat: set() for cat in self.target_categories}
|
578
|
+
|
579
|
+
for det in detections:
|
580
|
+
cat = det.get("category")
|
581
|
+
raw_track_id = det.get("track_id")
|
582
|
+
if cat not in self.target_categories or raw_track_id is None:
|
583
|
+
continue
|
584
|
+
bbox = det.get("bounding_box", det.get("bbox"))
|
585
|
+
canonical_id = self._merge_or_register_track(raw_track_id, bbox)
|
586
|
+
# Propagate canonical ID back to detection so downstream logic uses it
|
587
|
+
det["track_id"] = canonical_id
|
588
|
+
|
589
|
+
self._per_category_total_track_ids.setdefault(cat, set()).add(canonical_id)
|
590
|
+
self._current_frame_track_ids[cat].add(canonical_id)
|
591
|
+
|
592
|
+
def get_total_counts(self):
    """Return ``{category: number of unique track IDs seen so far}``.

    Returns an empty dict before any tracking state has been recorded.
    """
    per_category = getattr(self, '_per_category_total_track_ids', {})
    counts = {}
    for category, track_ids in per_category.items():
        counts[category] = len(track_ids)
    return counts
|
597
|
+
|
598
|
+
|
599
|
+
def _format_timestamp_for_stream(self, timestamp: float) -> str:
|
600
|
+
"""Format timestamp for streams (YYYY:MM:DD HH:MM:SS format)."""
|
601
|
+
dt = datetime.fromtimestamp(timestamp, tz=timezone.utc)
|
602
|
+
return dt.strftime('%Y:%m:%d %H:%M:%S')
|
603
|
+
|
604
|
+
def _format_timestamp_for_video(self, timestamp: float) -> str:
|
605
|
+
"""Format timestamp for video chunks (HH:MM:SS.ms format)."""
|
606
|
+
hours = int(timestamp // 3600)
|
607
|
+
minutes = int((timestamp % 3600) // 60)
|
608
|
+
seconds = round(float(timestamp % 60),2)
|
609
|
+
return f"{hours:02d}:{minutes:02d}:{seconds:.1f}"
|
610
|
+
|
611
|
+
def _get_current_timestamp_str(self, stream_info: Optional[Dict[str, Any]], precision=False, frame_id: Optional[str]=None) -> str:
|
612
|
+
"""Get formatted current timestamp based on stream type."""
|
613
|
+
if not stream_info:
|
614
|
+
return "00:00:00.00"
|
615
|
+
# is_video_chunk = stream_info.get("input_settings", {}).get("is_video_chunk", False)
|
616
|
+
if precision:
|
617
|
+
if stream_info.get("input_settings", {}).get("start_frame", "na") != "na":
|
618
|
+
if frame_id:
|
619
|
+
start_time = int(frame_id)/stream_info.get("input_settings", {}).get("original_fps", 30)
|
620
|
+
else:
|
621
|
+
start_time = stream_info.get("input_settings", {}).get("start_frame", 30)/stream_info.get("input_settings", {}).get("original_fps", 30)
|
622
|
+
stream_time_str = self._format_timestamp_for_video(start_time)
|
623
|
+
return stream_time_str
|
624
|
+
else:
|
625
|
+
return datetime.now(timezone.utc).strftime("%Y-%m-%d-%H:%M:%S.%f UTC")
|
626
|
+
|
627
|
+
if stream_info.get("input_settings", {}).get("start_frame", "na") != "na":
|
628
|
+
if frame_id:
|
629
|
+
start_time = int(frame_id)/stream_info.get("input_settings", {}).get("original_fps", 30)
|
630
|
+
else:
|
631
|
+
start_time = stream_info.get("input_settings", {}).get("start_frame", 30)/stream_info.get("input_settings", {}).get("original_fps", 30)
|
632
|
+
stream_time_str = self._format_timestamp_for_video(start_time)
|
633
|
+
return stream_time_str
|
634
|
+
else:
|
635
|
+
# For streams, use stream_time from stream_info
|
636
|
+
stream_time_str = stream_info.get("input_settings", {}).get("stream_info", {}).get("stream_time", "")
|
637
|
+
if stream_time_str:
|
638
|
+
# Parse the high precision timestamp string to get timestamp
|
639
|
+
try:
|
640
|
+
# Remove " UTC" suffix and parse
|
641
|
+
timestamp_str = stream_time_str.replace(" UTC", "")
|
642
|
+
dt = datetime.strptime(timestamp_str, "%Y-%m-%d-%H:%M:%S.%f")
|
643
|
+
timestamp = dt.replace(tzinfo=timezone.utc).timestamp()
|
644
|
+
return self._format_timestamp_for_stream(timestamp)
|
645
|
+
except:
|
646
|
+
# Fallback to current time if parsing fails
|
647
|
+
return self._format_timestamp_for_stream(time.time())
|
648
|
+
else:
|
649
|
+
return self._format_timestamp_for_stream(time.time())
|
650
|
+
|
651
|
+
def _get_start_timestamp_str(self, stream_info: Optional[Dict[str, Any]], precision=False) -> str:
    """Return the formatted start timestamp used for the 'TOTAL SINCE' label.

    Input whose ``input_settings.start_frame`` is present (any value other than
    the ``"na"`` default) is treated as video; anything else is treated as a
    live stream.

    Args:
        stream_info: Stream metadata dict, or None/empty.
        precision: When True, video input yields ``"00:00:00"`` and stream
            input yields the current UTC time formatted as
            ``YYYY-MM-DD-HH:MM:SS.ffffff UTC``.

    Returns:
        ``"00:00:00"`` for empty ``stream_info`` or video input. For streams
        (``precision`` False), the tracking start time truncated to the hour,
        formatted ``YYYY:MM:DD HH:00:00`` in UTC.

    Side effects:
        For streams with ``precision`` False, lazily initialises
        ``self._tracking_start_time`` on first call. The value comes from
        ``input_settings.stream_info.stream_time``
        (``YYYY-MM-DD-HH:MM:SS.ffffff UTC``) when it parses, else from
        ``time.time()``. Later calls reuse the cached value.
    """
    if not stream_info:
        return "00:00:00"

    # `or {}` also tolerates an explicit None for input_settings.
    input_settings = stream_info.get("input_settings") or {}
    is_video = input_settings.get("start_frame", "na") != "na"

    if precision:
        if is_video:
            return "00:00:00"
        return datetime.now(timezone.utc).strftime("%Y-%m-%d-%H:%M:%S.%f UTC")

    if is_video:
        # Video input: 'TOTAL SINCE' counts from the beginning of the clip.
        return "00:00:00"

    if self._tracking_start_time is None:
        stream_time_str = (input_settings.get("stream_info") or {}).get("stream_time", "")
        start_ts = None
        if stream_time_str:
            try:
                timestamp_str = stream_time_str.replace(" UTC", "")
                parsed = datetime.strptime(timestamp_str, "%Y-%m-%d-%H:%M:%S.%f")
                start_ts = parsed.replace(tzinfo=timezone.utc).timestamp()
            except (ValueError, TypeError, AttributeError):
                # Malformed or non-string stream_time: fall back to wall-clock time.
                start_ts = None
        self._tracking_start_time = start_ts if start_ts is not None else time.time()

    started = datetime.fromtimestamp(self._tracking_start_time, tz=timezone.utc)
    # 'TOTAL SINCE' is reported at hour granularity: zero minutes/seconds/micros.
    started = started.replace(minute=0, second=0, microsecond=0)
    return started.strftime('%Y:%m:%d %H:%M:%S')
|
685
|
+
|
686
|
+
|
687
|
+
def _count_categories(self, detections: list, config: SmokerDetectionConfig) -> dict:
    """Summarise detections by category.

    Each detection is expected to be a dict. The category count uses
    ``'unknown'`` when a detection has no ``category`` key. The returned
    ``detections`` list copies ``bounding_box``, ``category``, ``confidence``,
    ``track_id`` and ``frame_id`` from each input, with None for missing keys.
    ``config`` is accepted for interface compatibility but is not read here.

    Returns:
        dict with ``total_count``, ``per_category_count`` and ``detections``.
    """
    per_category = {}
    for detection in detections:
        label = detection.get('category', 'unknown')
        per_category[label] = 1 + per_category.get(label, 0)

    exported_keys = ("bounding_box", "category", "confidence", "track_id", "frame_id")
    rows = [
        {key: detection.get(key) for key in exported_keys}
        for detection in detections
    ]

    return {
        "total_count": sum(per_category.values()),
        "per_category_count": per_category,
        "detections": rows,
    }
|
712
|
+
|
713
|
+
def _extract_predictions(self, detections: list) -> List[Dict[str, Any]]:
    """Project each detection dict down to category, confidence and bounding box.

    Missing values default to ``"unknown"``, ``0.0`` and a fresh empty dict
    respectively.
    """
    predictions: List[Dict[str, Any]] = []
    for detection in detections:
        predictions.append({
            "category": detection.get("category", "unknown"),
            "confidence": detection.get("confidence", 0.0),
            "bounding_box": detection.get("bounding_box", {}),
        })
    return predictions
|
725
|
+
|
726
|
+
# ------------------------------------------------------------------ #
|
727
|
+
# Canonical ID helpers #
|
728
|
+
# ------------------------------------------------------------------ #
|
729
|
+
def _compute_iou(self, box1: Any, box2: Any) -> float:
    """Compute the intersection-over-union of two axis-aligned boxes.

    Accepted box forms:
      * list/tuple ``[x1, y1, x2, y2, ...]`` (extra items ignored);
      * dict with ``xmin/ymin/xmax/ymax`` or ``x1/y1/x2/y2`` keys;
      * any other dict: its first four int/float values in insertion order.

    Corner order is normalised, so swapped min/max corners are tolerated.
    Returns 0.0 when either box cannot be interpreted (None, fewer than four
    values, or a dict that uses a named format but lacks one of its keys) or
    when the union area is zero.
    """

    def _bbox_to_list(bbox):
        """Normalise a bbox to [x1, y1, x2, y2], or [] if it can't be read."""
        if bbox is None:
            return []
        if isinstance(bbox, (list, tuple)):
            return list(bbox[:4]) if len(bbox) >= 4 else []
        if isinstance(bbox, dict):
            for keys in (("xmin", "ymin", "xmax", "ymax"), ("x1", "y1", "x2", "y2")):
                if keys[0] in bbox:
                    # Named format detected: all four keys are required,
                    # otherwise treat as insufficient data instead of raising.
                    if all(k in bbox for k in keys):
                        return [bbox[k] for k in keys]
                    return []
            # Fallback: first four numeric values
            values = [v for v in bbox.values() if isinstance(v, (int, float))]
            return values[:4] if len(values) >= 4 else []
        return []

    l1 = _bbox_to_list(box1)
    l2 = _bbox_to_list(box2)
    if len(l1) < 4 or len(l2) < 4:
        return 0.0
    x1_min, y1_min, x1_max, y1_max = l1
    x2_min, y2_min, x2_max, y2_max = l2

    # Ensure correct order
    x1_min, x1_max = min(x1_min, x1_max), max(x1_min, x1_max)
    y1_min, y1_max = min(y1_min, y1_max), max(y1_min, y1_max)
    x2_min, x2_max = min(x2_min, x2_max), max(x2_min, x2_max)
    y2_min, y2_max = min(y2_min, y2_max), max(y2_min, y2_max)

    inter_w = max(0.0, min(x1_max, x2_max) - max(x1_min, x2_min))
    inter_h = max(0.0, min(y1_max, y2_max) - max(y1_min, y2_min))
    inter_area = inter_w * inter_h

    area1 = (x1_max - x1_min) * (y1_max - y1_min)
    area2 = (x2_max - x2_min) * (y2_max - y2_min)
    union_area = area1 + area2 - inter_area

    return (inter_area / union_area) if union_area > 0 else 0.0
|
776
|
+
|
777
|
+
def _merge_or_register_track(self, raw_id: Any, bbox: Any) -> Any:
    """Map a raw tracker ID to a stable canonical ID, merging fragmented tracks.

    Resolution order:
      1. ``raw_id`` or ``bbox`` is None: return ``raw_id`` unchanged.
      2. ``raw_id`` already aliased: return its canonical ID, refreshing that
         track's ``last_bbox``, ``last_update`` and ``raw_ids`` if it exists.
      3. Otherwise, among canonical tracks whose ``last_update`` is within
         ``self._track_merge_time_window`` seconds, choose the one whose
         ``last_bbox`` has the highest IoU with ``bbox``. Its IoU must be
         >= ``self._track_merge_iou_threshold``. That track absorbs ``raw_id``.
      4. If no track qualifies, ``raw_id`` becomes a new canonical track.

    Mutates ``self._track_aliases`` and ``self._canonical_tracks``.
    """
    if raw_id is None or bbox is None:
        # Nothing to merge
        return raw_id

    now = time.time()

    # Fast path – raw_id already mapped
    if raw_id in self._track_aliases:
        canonical_id = self._track_aliases[raw_id]
        track_info = self._canonical_tracks.get(canonical_id)
        if track_info is not None:
            track_info["last_bbox"] = bbox
            track_info["last_update"] = now
            track_info["raw_ids"].add(raw_id)
        return canonical_id

    # Pick the best-overlapping recent track rather than the first one that
    # clears the threshold, so a fragment is not attached to the wrong object
    # when several recent tracks overlap it. Ties keep iteration order.
    best_id = None
    best_info = None
    best_iou = 0.0
    for canonical_id, info in self._canonical_tracks.items():
        # Only consider recently updated tracks
        if now - info["last_update"] > self._track_merge_time_window:
            continue
        iou = self._compute_iou(bbox, info["last_bbox"])
        if iou >= self._track_merge_iou_threshold and (best_info is None or iou > best_iou):
            best_id, best_info, best_iou = canonical_id, info, iou

    if best_info is not None:
        # Merge
        self._track_aliases[raw_id] = best_id
        best_info["last_bbox"] = bbox
        best_info["last_update"] = now
        best_info["raw_ids"].add(raw_id)
        return best_id

    # No match – register new canonical track
    canonical_id = raw_id
    self._track_aliases[raw_id] = canonical_id
    self._canonical_tracks[canonical_id] = {
        "last_bbox": bbox,
        "last_update": now,
        "raw_ids": {raw_id},
    }
    return canonical_id
|
820
|
+
|
821
|
+
def _format_timestamp(self, timestamp: float) -> str:
    """Render a POSIX timestamp in UTC, e.g. ``1970-01-01 00:00:00 UTC``."""
    moment = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    return moment.strftime('%Y-%m-%d %H:%M:%S UTC')
|
824
|
+
|
825
|
+
def _get_tracking_start_time(self) -> str:
    """Return ``self._tracking_start_time`` formatted via ``self._format_timestamp``.

    Returns ``"N/A"`` when the start time has not been set (is None).
    """
    started_at = self._tracking_start_time
    return "N/A" if started_at is None else self._format_timestamp(started_at)
|
830
|
+
|
831
|
+
def _set_tracking_start_time(self) -> None:
    """Record the current wall-clock time (``time.time()``) as the tracking start."""
    started_at = time.time()
    self._tracking_start_time = started_at
|