openvisionkit 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openvisionkit/__init__.py +1 -0
- openvisionkit/_version.py +24 -0
- openvisionkit/capture/draw_object.py +296 -0
- openvisionkit/capture/image_template.py +61 -0
- openvisionkit/capture/screen_capture.py +13 -0
- openvisionkit/capture/video_recorder.py +128 -0
- openvisionkit/capture/video_template.py +336 -0
- openvisionkit/lib/classifier.py +186 -0
- openvisionkit/lib/face_detector.py +587 -0
- openvisionkit/lib/face_mesh_detector.py +913 -0
- openvisionkit/lib/form_detector.py +465 -0
- openvisionkit/lib/form_roi_annotator.py +679 -0
- openvisionkit/lib/form_roi_detector.py +1078 -0
- openvisionkit/lib/fps_counter.py +38 -0
- openvisionkit/lib/hair_segmentation.py +298 -0
- openvisionkit/lib/hand_detector.py +1230 -0
- openvisionkit/lib/image_detector.py +1095 -0
- openvisionkit/lib/object_detector.py +401 -0
- openvisionkit/lib/pose_detector.py +919 -0
- openvisionkit/lib/selfie_segmentation.py +528 -0
- openvisionkit/lib/text_detector.py +1229 -0
- openvisionkit/utility/live_plot.py +141 -0
- openvisionkit/utility/vision_utilis.py +871 -0
- openvisionkit-0.4.0.dist-info/METADATA +1018 -0
- openvisionkit-0.4.0.dist-info/RECORD +26 -0
- openvisionkit-0.4.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import ctypes
|
|
3
|
+
import time
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import cv2
|
|
9
|
+
|
|
10
|
+
from openvisionkit.capture.video_recorder import VideoRecorder
|
|
11
|
+
from openvisionkit.lib.fps_counter import FPSCounter
|
|
12
|
+
|
|
13
|
+
# Best-effort, import-time side effect: ask Windows to treat this process as
# DPI-aware. Every exception is suppressed on purpose, so importing this module
# still succeeds where the call cannot be made (e.g. on platforms where
# ``ctypes.windll`` does not exist).
# NOTE(review): presumably done so window sizes/positions are expressed in
# physical pixels on scaled displays — confirm.
with contextlib.suppress(Exception):
    ctypes.windll.user32.SetProcessDPIAware()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class KeyEventManager:
    """Registry that binds key codes to callbacks and dispatches key presses."""

    def __init__(self):
        # Maps a key code to the callable invoked as ``callback(frame, state)``.
        self.handlers = {}

    def register(self, key, callback):
        """Bind ``callback`` to ``key``, replacing any previous binding.

        key: ord('r'), ord('p'), etc.
        callback(frame, state)
        """
        self.handlers[key] = callback

    def handle(self, key, frame, state):
        """Call the callback bound to ``key`` with ``(frame, state)``.

        Keys without a registered callback are ignored.
        """
        try:
            callback = self.handlers[key]
        except KeyError:
            return
        callback(frame, state)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def save_screenshot(frame, output_dir="screenshots", prefix="capture"):
    """Save a single frame as a timestamped PNG file.

    The file is named ``{prefix}_{YYYYmmdd_HHMMSS_microseconds}.png`` and a
    status line is printed to stdout.

    Args:
        frame (numpy.ndarray): BGR image to save; handed to ``cv2.imwrite``
            unchanged.
        output_dir (str): Directory where the file is written. Created
            (including parents) if absent. Default is 'screenshots'.
        prefix (str): Filename prefix before the timestamp. Default is
            'capture'.

    Returns:
        str: Path of the target PNG file, built from ``output_dir`` as given
            (so it is relative when ``output_dir`` is relative). The path is
            returned even when the write fails; the printed message tells
            the two cases apart.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = target_dir / f"{prefix}_{timestamp}.png"

    # cv2.imwrite reports failure through its boolean return value, so the
    # success message must not be printed unconditionally.
    if cv2.imwrite(str(filename), frame):
        print(f"📸 Screenshot saved: {filename}")
    else:
        print(f"⚠️ Failed to save screenshot: {filename}")
    return str(filename)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def video_capture_template(
    video_source: int | str = 0,
    loop_forever: bool = True,
    custom_logic: Callable[[cv2.typing.MatLike], cv2.typing.MatLike] | None = None,
    state: dict | None = None,
    key_manager: KeyEventManager | None = None,
    window_name: str = "Demo",
    show_window: bool = True,
    resolution: tuple[int, int] = (1280, 720),
    center_window: bool = True,
    draw_fps: bool = True,
    fps=15,
    # MOUSE CALLBACK OPTION
    mouse_callback: Callable | None = None,
    mouse_callback_params: dict | None = None,
    # VIDEO RECORDING OPTIONS
    enable_auto_recording: bool = False,
    enable_manual_recording: bool = False,
    record_format="mp4",  # "mp4" | "gif"
    # SCREENSHOT OPTIONS
    enable_screenshot: bool = False,
    screenshot_output_dir: str = "screenshots",
    screenshot_prefix: str = "capture",
    auto_screenshot_after_seconds: float | None = None,
    auto_screenshot_repeat: bool = False,
):
    """
    REUSABLE TEMPLATE for all OpenCV video demos.

    Opens ``video_source``, then loops: read a frame, run ``custom_logic``,
    optionally draw FPS / record / take screenshots, show the frame, and
    dispatch key presses. The loop ends when ESC is pressed or a frame cannot
    be read. The capture, the windows and any open recorder are released in a
    ``finally`` clause, so they are cleaned up even when ``custom_logic``, a
    key handler or a recorder raises.

    How to use:
        1. Define your own logic as a function that takes a frame and returns
           the processed frame.
        2. Call this template with the video source and your logic function.
        3. FPS counter, ESC exit, resolution control, and window centering are
           already handled.

    Parameters:
        video_source (int or str):
            - int (e.g. 0, 1, 2...) → camera index
            - str → path to video file (mp4, avi, etc.)
        loop_forever (bool): If True, rewinds to frame 0 once the current
            position reaches the reported frame count. Sources that report no
            positive frame count (e.g. live cameras) are never rewound.
            Default = True
        custom_logic (callable, optional):
            Function that receives the frame and returns the modified frame.
            This is where you put ALL your own logic (blink detection, face
            detection, etc.).
        state (dict, optional):
            A dictionary that is passed to key handlers and can be used to
            store game state, scores, or any other information you need to
            persist across frames and key events. When None, a fresh empty
            dict is used. For example:
                state = {'score': [0, 0], 'game_over': False}
        key_manager (KeyEventManager, optional): Receives every polled key
            code (together with the frame and ``state``) except ESC.
            Default = None
        window_name (str): Name of the OpenCV window. Default = "Demo"
        show_window (bool): If True, displays the video window. Default = True
        resolution (tuple[int, int]): Desired capture resolution
            (width, height); also used as the initial window size and for
            centering. Default = (1280, 720)
        center_window (bool): If True, centers the window on screen once
            (after the first frame has been shown) using pyautogui. Centering
            is skipped silently if it fails. Default = True
        draw_fps (bool): If True, passes each frame through FPSCounter, which
            returns the frame and the current FPS. Default = True
        fps: Initial FPS estimate handed to the recorders. When draw_fps is
            True it is replaced every frame by the value from FPSCounter. A
            falsy or non-positive estimate makes the recorder use 10.
            Default = 15

        # MOUSE CALLBACK OPTION
        mouse_callback (callable, optional): Function to handle mouse events
            on the window (only installed when show_window is True).
            Default = None
        mouse_callback_params (dict, optional): Additional parameters passed
            to the mouse callback function. Default = None

        # VIDEO RECORDING OPTIONS
        enable_auto_recording (bool): If True, records every frame from the
            first one until the loop ends. Default = False
        enable_manual_recording (bool): If True, pressing 'r' or 'R' toggles
            recording; each start opens a new recorder. Default = False
        record_format (str): Format for recording output ("mp4" or "gif").
            Default = "mp4"

        # SCREENSHOT OPTIONS
        enable_screenshot (bool): If True, allows taking screenshots by
            pressing 's' or 'S' and enables the auto-screenshot options.
            Default = False
        screenshot_output_dir (str): Directory where screenshots will be
            saved. Default = "screenshots"
        screenshot_prefix (str): Prefix for screenshot filenames.
            Default = "capture"
        auto_screenshot_after_seconds (float, optional): If set, automatically
            takes a screenshot after this many seconds. Default = None
            (disabled)
        auto_screenshot_repeat (bool): If True and
            auto_screenshot_after_seconds is set, continues to take
            screenshots at the specified interval. Default = False

    Returns:
        None. If the source cannot be opened, an error is printed and the
        function returns immediately.

    Usage:

        1. Screenshot:
            For repeated auto screenshots every 5 seconds:
                video_capture_template(
                    video_source=0,
                    custom_logic=my_logic,
                    enable_screenshot=True,
                    auto_screenshot_after_seconds=5,
                    auto_screenshot_repeat=True,
                )
            For manual screenshots with 's' key:
                video_capture_template(
                    video_source=0,
                    custom_logic=my_logic,
                    enable_screenshot=True,
                    auto_screenshot_after_seconds=None,
                    auto_screenshot_repeat=False,
                )
    """
    cap = cv2.VideoCapture(video_source)

    if not cap.isOpened():
        print(f"Error: Could not open video source '{video_source}'")
        return

    # Declared before the try block so the cleanup code can always see them.
    auto_recorder: VideoRecorder | None = None
    manual_recorder: VideoRecorder | None = None

    try:
        frame_width, frame_height = resolution
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, frame_width)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, frame_height)

        window_centered = False
        first_frame_rendered = False

        if state is None:
            state = {}

        manual_recording = False  # True while the user is recording

        if draw_fps:
            fps_counter = FPSCounter()

        current_fps = fps  # will be updated each frame when draw_fps is True

        start_time = time.time()
        last_auto_screenshot_time = start_time
        auto_screenshot_done = False

        if show_window:
            cv2.namedWindow(window_name, cv2.WINDOW_NORMAL | cv2.WINDOW_GUI_EXPANDED)
            cv2.resizeWindow(window_name, frame_width, frame_height)
            if mouse_callback is not None:
                cv2.setMouseCallback(window_name, mouse_callback, mouse_callback_params)

        while True:
            if loop_forever:
                # Only rewind when the source reports a real frame count.
                # Without this guard a count of 0 or -1 would trigger a seek
                # to frame 0 on every single iteration.
                total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
                if total_frames > 0 and cap.get(cv2.CAP_PROP_POS_FRAMES) >= total_frames:
                    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

            ret, frame = cap.read()
            if not ret:
                print("End of video stream or failed to read frame.")
                break

            if custom_logic is not None:
                frame = custom_logic(frame)

            if draw_fps:
                frame, current_fps = fps_counter.update(frame)

            # ── AUTO RECORDING ────────────────────────────────────────────
            if enable_auto_recording:
                if auto_recorder is None:
                    # Created and started together on the first frame.
                    safe_fps = current_fps if current_fps and current_fps > 0 else 10
                    print("Initializing auto-recorder with FPS:", safe_fps)
                    auto_recorder = VideoRecorder(output_format=record_format, fps=safe_fps)
                    auto_recorder.start(frame.shape)

                auto_recorder.write(frame)

                # Drawn after write(), so the label is not part of the auto
                # recording itself.
                cv2.putText(
                    frame,
                    "REC (AUTO)",
                    (20, 80),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1,
                    (0, 255, 0),
                    2,
                )

            # ── MANUAL RECORDING ──────────────────────────────────────────
            if enable_manual_recording and manual_recording:
                if manual_recorder is None:
                    # Initialise lazily on the first frame after R is pressed
                    safe_fps = current_fps if current_fps and current_fps > 0 else 10
                    print("Initializing manual recorder with FPS:", safe_fps)
                    manual_recorder = VideoRecorder(
                        output_format=record_format, fps=safe_fps
                    )
                    manual_recorder.start(frame.shape)

                manual_recorder.write(frame)

                cv2.putText(
                    frame,
                    "REC (MANUAL)",
                    (20, 120),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1,
                    (0, 0, 255),
                    2,
                )

            # ── AUTO SCREENSHOT ───────────────────────────────────────────
            if enable_screenshot and auto_screenshot_after_seconds is not None:
                now = time.time()
                if auto_screenshot_repeat:
                    # Repeating mode: measure from the previous screenshot.
                    if now - last_auto_screenshot_time >= auto_screenshot_after_seconds:
                        save_screenshot(
                            frame,
                            output_dir=screenshot_output_dir,
                            prefix=screenshot_prefix,
                        )
                        last_auto_screenshot_time = now
                elif (
                    not auto_screenshot_done
                    and now - start_time >= auto_screenshot_after_seconds
                ):
                    # One-shot mode: measure from the start of the session.
                    save_screenshot(
                        frame,
                        output_dir=screenshot_output_dir,
                        prefix=screenshot_prefix,
                    )
                    auto_screenshot_done = True

            if show_window:
                cv2.imshow(window_name, frame)

                # Centering waits until one frame has already been displayed
                # and is attempted exactly once.
                if center_window and not window_centered and first_frame_rendered:
                    try:
                        import pyautogui  # noqa: PLC0415

                        screen_width, screen_height = pyautogui.size()
                        x = (screen_width - frame_width) // 2
                        y = (screen_height - frame_height) // 2
                        cv2.moveWindow(window_name, x, y)
                    except Exception:
                        pass  # headless / no display — skip centering
                    window_centered = True

                first_frame_rendered = True

            # Masked to the low byte; "no key" therefore shows up as 255.
            key = cv2.waitKey(1) & 0xFF

            # ESC → exit
            if key == 27:
                print("Exiting cleanly...")
                break

            # Custom key handlers
            if key_manager:
                key_manager.handle(key, frame, state)

            # S → screenshot
            if enable_screenshot and key in [ord("s"), ord("S")]:
                save_screenshot(
                    frame, output_dir=screenshot_output_dir, prefix=screenshot_prefix
                )

            # R → toggle manual recording on/off
            if enable_manual_recording and key in [ord("r"), ord("R")]:
                manual_recording = not manual_recording

                if manual_recording:
                    # ── START ──────────────────────────────────────────
                    print("🎥 Manual recording: ON")
                    # Recorder is created fresh each time so a new file is opened
                    manual_recorder = None  # will be lazily created above on next frame

                else:
                    # ── STOP ───────────────────────────────────────────
                    print("⏹️  Manual recording: OFF — saving…")
                    if manual_recorder is not None:
                        manual_recorder.stop()
                        manual_recorder = None
    finally:
        # ── CLEANUP ───────────────────────────────────────────────────────
        # Runs on normal exit and on any exception raised inside the loop.
        cap.release()
        cv2.destroyAllWindows()

        if auto_recorder:
            print("Stopping auto-recorder…")
            auto_recorder.stop()

        if manual_recorder:
            print("Stopping manual recorder (cleanup)…")
            manual_recorder.stop()
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Simple Classifier for Teachable Machine .h5 models
|
|
3
|
+
Works well with TensorFlow 2.15 / 2.16 on Apple Silicon
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
import cv2
|
|
9
|
+
import numpy as np
|
|
10
|
+
import tensorflow as tf
|
|
11
|
+
from tensorflow.keras.models import load_model
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Classifier:
    """Image classifier wrapping a Keras model file and a text file of labels.

    The model is fed one 224x224, 3-channel float32 image at a time, scaled
    with ``pixel / 127.0 - 1.0``.
    """

    def __init__(self, model_path: str, labels_path: str):
        """Load the model and the labels.

        Args:
            model_path: Path to the Keras model file.
            labels_path: Path to a UTF-8 text file with one label per line;
                blank lines are skipped and surrounding whitespace stripped.
        Raises:
            FileNotFoundError: If either path does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model not found: {model_path}")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"Labels not found: {labels_path}")

        print(f"Loading model: {model_path}")
        self.model = load_model(model_path, compile=False)

        with open(labels_path, encoding="utf-8") as f:
            self.labels = [line.strip() for line in f if line.strip()]

        print(f"Model loaded | TF {tf.__version__} | {len(self.labels)} labels")

        # Reusable single-image input batch; slot 0 is overwritten on every
        # predict() call, so it does not need to be initialised.
        self.data = np.empty((1, 224, 224, 3), dtype=np.float32)

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Resize ``img`` to 224x224 and scale it with ``pixel / 127.0 - 1.0``.

        Args:
            img: Image array accepted by ``cv2.resize``.
        Returns:
            np.ndarray: float32 array of the resized, scaled image.
        """
        # NOTE(review): channel order is passed through unchanged (no
        # BGR→RGB conversion) — confirm this matches how the model was trained.
        resized = cv2.resize(img, (224, 224))
        array = np.asarray(resized, dtype=np.float32)
        return (array / 127.0) - 1.0

    def predict(self, img: np.ndarray) -> tuple[list[float], int, str]:
        """Classify one image.

        Args:
            img: Image to classify (see preprocess()).
        Returns:
            tuple: ``(probs, index, label)`` — the model output for the image
            as a list, the arg-max index, and its label (``"Class {index}"``
            when the index has no entry in the labels file).
        """
        self.data[0] = self.preprocess(img)

        predictions = self.model.predict(self.data, verbose=0)
        probs = predictions[0].tolist()

        index = int(np.argmax(predictions))
        label = self.labels[index] if index < len(self.labels) else f"Class {index}"

        return probs, index, label

    def getPrediction(
        self,
        img: np.ndarray,
        draw: bool = True,
        pos: tuple[int, int] = (30, 50),
        scale: float = 1.5,
        color: tuple[int, int, int] = (0, 255, 0),
        thickness: int = 2,
    ) -> tuple[list[float], int]:
        """Classify ``img`` and optionally draw the predicted label onto it.

        Args:
            img: Image to classify; modified in place when ``draw`` is True.
            draw: Whether to write the label onto ``img``.
            pos: Text origin passed to ``cv2.putText``.
            scale: Font scale.
            color: Text colour.
            thickness: Text stroke thickness.
        Returns:
            tuple: ``(probs, index)`` as produced by predict().
        """
        probs, index, label = self.predict(img)

        if draw:
            cv2.putText(
                img, label, pos, cv2.FONT_HERSHEY_COMPLEX, scale, color, thickness
            )

        return probs, index

    def get_label(self, index: int) -> str:
        """Return the label for ``index``, or ``"Unknown ({index})"`` if out of range."""
        return (
            self.labels[index]
            if 0 <= index < len(self.labels)
            else f"Unknown ({index})"
        )

    # ─────────────────────────── NEW METHODS ───────────────────────────

    def get_confidence(self, probs: list, index: int) -> float:
        """Return the confidence percentage for a specific class index.

        Args:
            probs: Probability list from predict().
            index: Class index to query.
        Returns:
            float: ``probs[index] * 100``; 0.0 when ``probs`` is empty or
            ``index`` is outside ``0 .. len(probs) - 1``.
        """
        # Negative indices are rejected too, so they cannot silently wrap
        # around to the end of the list (same range rule as get_label()).
        if not probs or not 0 <= index < len(probs):
            return 0.0
        return probs[index] * 100.0

    def predict_top_n(self, img: np.ndarray, n: int = 3):
        """Return the top-N predictions sorted by descending confidence.

        Args:
            img: BGR numpy array.
            n: Number of top predictions to return. Values below 1 yield an
                empty list; values above the class count yield every class.
        Returns:
            List[dict]: [{'label': str, 'index': int, 'confidence': float}, ...]
        """
        probs, _, _ = self.predict(img)
        # Clamp so a negative n cannot turn the slice into "all but the last".
        limit = max(n, 0)
        indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:limit]
        return [
            {"label": self.get_label(i), "index": i, "confidence": probs[i] * 100.0}
            for i in indices
        ]

    def get_all_predictions(self, probs: list):
        """Return all class predictions paired with their labels.

        Args:
            probs: Probability list from predict().
        Returns:
            List[dict]: [{'label': str, 'index': int, 'confidence': float}] sorted desc.
        """
        return sorted(
            [
                {"label": self.get_label(i), "index": i, "confidence": p * 100.0}
                for i, p in enumerate(probs)
            ],
            key=lambda x: x["confidence"],
            reverse=True,
        )

    def is_confident(self, probs: list, threshold: float = 70.0) -> bool:
        """Return True if the top prediction confidence meets the threshold.

        Args:
            probs: Probability list from predict().
            threshold: Minimum confidence percentage (default 70 %).
        Returns:
            bool: False when ``probs`` is empty.
        """
        if not probs:
            return False
        return max(probs) * 100.0 >= threshold

    def predict_batch(self, images: list):
        """Run predict() on a list of images and return all results.

        Images are classified one at a time, in order.

        Args:
            images: List of BGR numpy arrays.
        Returns:
            List[dict]: [{'label': str, 'index': int, 'confidence': float}]
        """
        results = []
        for img in images:
            probs, index, label = self.predict(img)
            results.append(
                {
                    "label": label,
                    "index": index,
                    "confidence": probs[index] * 100.0,
                }
            )
        return results
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# ====================== Quick Test ======================
|
|
156
|
+
|
|
157
|
+
if __name__ == "__main__":
    # Manual smoke test: classify frames from camera 0 and show the label
    # on screen until 'q' is pressed or a frame cannot be read.
    MODEL_PATH = "hand-gesture/hand-sign-detection/model/keras_model.h5"
    LABELS_PATH = "hand-gesture/hand-sign-detection/model/labels.txt"

    classifier = Classifier(MODEL_PATH, LABELS_PATH)

    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Camera not opened. Try changing to cv2.VideoCapture(1)")
        # ``exit()`` is a helper injected by the ``site`` module and is not
        # guaranteed to exist; raising SystemExit is the equivalent that
        # always works.
        raise SystemExit

    print("Press 'q' to quit")

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            probs, idx = classifier.getPrediction(
                frame, draw=True, scale=1.7, color=(0, 255, 100)
            )
            conf = probs[idx] * 100
            print(f"→ {classifier.get_label(idx)} | Confidence: {conf:.1f}%")

            cv2.imshow("Classifier", frame)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        # Release the camera and close the window even if prediction raises
        # or the user interrupts with Ctrl+C.
        cap.release()
        cv2.destroyAllWindows()
|