alchemydetect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alchemydetect/__init__.py +0 -0
- alchemydetect/app.py +23 -0
- alchemydetect/core/__init__.py +0 -0
- alchemydetect/core/config_builder.py +90 -0
- alchemydetect/core/dataset_utils.py +62 -0
- alchemydetect/core/inferencer.py +67 -0
- alchemydetect/core/model_catalog.py +56 -0
- alchemydetect/core/trainer.py +67 -0
- alchemydetect/gui/__init__.py +0 -0
- alchemydetect/gui/dialogs.py +81 -0
- alchemydetect/gui/image_viewer.py +61 -0
- alchemydetect/gui/inference_tab.py +249 -0
- alchemydetect/gui/log_viewer.py +22 -0
- alchemydetect/gui/loss_plot.py +31 -0
- alchemydetect/gui/main_window.py +25 -0
- alchemydetect/gui/train_tab.py +289 -0
- alchemydetect/workers/__init__.py +0 -0
- alchemydetect/workers/inference_worker.py +84 -0
- alchemydetect/workers/train_worker.py +147 -0
- alchemydetect-0.1.0.dist-info/METADATA +99 -0
- alchemydetect-0.1.0.dist-info/RECORD +25 -0
- alchemydetect-0.1.0.dist-info/WHEEL +5 -0
- alchemydetect-0.1.0.dist-info/entry_points.txt +2 -0
- alchemydetect-0.1.0.dist-info/licenses/LICENSE +21 -0
- alchemydetect-0.1.0.dist-info/top_level.txt +1 -0
|
File without changes
|
alchemydetect/app.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Entry point for the AlchemyDetect application."""
|
|
2
|
+
|
|
3
|
+
import multiprocessing
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
from PyQt6.QtWidgets import QApplication
|
|
7
|
+
|
|
8
|
+
from alchemydetect.gui.main_window import MainWindow
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def main():
    """Start the AlchemyDetect GUI.

    Creates the Qt application object, shows the main window and blocks in
    the Qt event loop; the process then exits with the loop's return code.
    """
    qt_app = QApplication(sys.argv)
    qt_app.setApplicationName("AlchemyDetect")

    main_window = MainWindow()
    main_window.show()

    exit_code = qt_app.exec()
    sys.exit(exit_code)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Script entry point.
if __name__ == "__main__":
    # Called first, before main(), as the multiprocessing docs require for
    # frozen (e.g. PyInstaller) Windows executables; it has no effect otherwise.
    multiprocessing.freeze_support()
    main()
|
|
File without changes
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Build Detectron2 config from user-specified parameters."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from detectron2 import model_zoo
|
|
5
|
+
from detectron2.config import get_cfg
|
|
6
|
+
|
|
7
|
+
from .dataset_utils import get_num_classes, register_coco_dataset
|
|
8
|
+
from .model_catalog import get_config_path
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def build_cfg(
    model_name,
    train_images_dir,
    train_json,
    output_dir,
    lr=0.0025,
    max_iter=1000,
    batch_size=2,
    val_images_dir=None,
    val_json=None,
    resume=False,
    weights_path=None,
    num_workers=2,
):
    """Build a frozen Detectron2 CfgNode from user selections.

    Registers the training dataset (and, when both ``val_json`` and
    ``val_images_dir`` are given, a validation dataset) under fixed names,
    then layers the user's choices on top of the model-zoo base config.

    Args:
        model_name: Key from MODEL_ZOO (e.g. "Faster R-CNN (R50-FPN)").
        train_images_dir: Path to the training images directory.
        train_json: Path to COCO JSON annotations for training. Its category
            count sets NUM_CLASSES on the model heads.
        output_dir: Directory to save checkpoints and logs (str or Path).
        lr: Base learning rate.
        max_iter: Maximum training iterations.
        batch_size: Images per batch (SOLVER.IMS_PER_BATCH).
        val_images_dir: Optional path to the validation images directory.
        val_json: Optional path to the validation COCO JSON.
        resume: Accepted for call-site compatibility but NOT read by this
            function. NOTE(review): resuming presumably happens in the caller
            (e.g. via the trainer); confirm there.
        weights_path: Optional path to custom weights (.pth, str or Path).
            If falsy, the model-zoo pretrained checkpoint URL is used.
        num_workers: Dataloader worker processes (DATALOADER.NUM_WORKERS).
            Defaults to 2, the value previously hard-coded here.

    Returns:
        Frozen CfgNode ready for training.
    """
    config_path = get_config_path(model_name)
    num_classes = get_num_classes(train_json)

    # Register datasets under fixed names. register_coco_dataset drops any
    # previous registration first, so repeated runs pick up the new paths.
    train_dataset_name = "alchemy_train"
    register_coco_dataset(train_dataset_name, train_json, train_images_dir)

    val_dataset_name = None
    if val_json and val_images_dir:
        val_dataset_name = "alchemy_val"
        register_coco_dataset(val_dataset_name, val_json, val_images_dir)

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_path))

    # Datasets
    cfg.DATASETS.TRAIN = (train_dataset_name,)
    cfg.DATASETS.TEST = (val_dataset_name,) if val_dataset_name else ()

    # Dataloader
    cfg.DATALOADER.NUM_WORKERS = num_workers

    # Model weights: user checkpoint if given, else the zoo's pretrained one.
    # str() keeps the config YAML-serializable when a pathlib.Path is passed.
    if weights_path:
        cfg.MODEL.WEIGHTS = str(weights_path)
    else:
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_path)

    # Solver
    cfg.SOLVER.IMS_PER_BATCH = batch_size
    cfg.SOLVER.BASE_LR = lr
    cfg.SOLVER.MAX_ITER = max_iter
    cfg.SOLVER.STEPS = []  # No LR decay for simplicity
    # About five checkpoints per run, but never more often than every 100 iterations.
    cfg.SOLVER.CHECKPOINT_PERIOD = max(max_iter // 5, 100)

    # Number of classes: set it on every head section the config defines.
    if hasattr(cfg.MODEL, "ROI_HEADS"):
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    if hasattr(cfg.MODEL, "RETINANET"):
        cfg.MODEL.RETINANET.NUM_CLASSES = num_classes

    # Device: GPU when torch can see one, otherwise CPU.
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Output (str() for the same serialization reason as MODEL.WEIGHTS).
    cfg.OUTPUT_DIR = str(output_dir)

    cfg.freeze()
    return cfg
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""COCO dataset registration helpers for Detectron2."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
|
7
|
+
from detectron2.data.datasets import register_coco_instances
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def register_coco_dataset(name, json_path, image_root):
    """Register a COCO-format dataset under ``name``, replacing any existing entry.

    Any previous DatasetCatalog and MetadataCatalog entries for ``name`` are
    removed first. Calling this again with new paths therefore takes effect,
    instead of failing on a duplicate name or on metadata that conflicts with
    the new files.

    Args:
        name: Dataset name to register.
        json_path: Path to the COCO JSON annotation file.
        image_root: Directory containing the images referenced by the JSON.
    """
    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)
    # Checked independently of the dataset entry: metadata can exist on its
    # own (MetadataCatalog.get(name) creates it), and stale values would
    # clash with the ones register_coco_instances sets.
    if name in MetadataCatalog.list():
        MetadataCatalog.remove(name)
    register_coco_instances(name, {}, json_path, image_root)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def validate_coco_json(json_path, image_root):
    """Validate a COCO-format annotation file against an image directory.

    Checks, in order:
      * the annotation file exists and the image root is an existing directory;
      * the file parses as UTF-8 JSON and its top level is an object;
      * "images", "annotations" and "categories" are present and are lists;
      * "categories" and "images" are non-empty;
      * at least one of the first 10 listed image files exists under ``image_root``.

    Never raises for bad input; every failure is reported through the return value.

    Args:
        json_path: Path (str or Path) to the COCO JSON annotation file.
        image_root: Path (str or Path) to the directory holding the images.

    Returns:
        Tuple ``(is_valid, message)``. On failure the message describes the
        problem; on success it summarises the image and category counts.
    """
    json_path = Path(json_path)
    image_root = Path(image_root)

    if not json_path.exists():
        return False, f"Annotation file not found: {json_path}"

    if not image_root.is_dir():
        return False, f"Image directory not found: {image_root}"

    try:
        # COCO JSON is UTF-8; do not depend on the platform locale encoding.
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        return False, f"Invalid JSON: {e}"
    except UnicodeDecodeError as e:
        return False, f"Annotation file is not valid UTF-8: {e}"
    except OSError as e:
        return False, f"Could not read annotation file: {e}"

    if not isinstance(data, dict):
        return False, "COCO JSON must be an object at the top level"

    for key in ("images", "annotations", "categories"):
        if key not in data:
            return False, f"Missing required key '{key}' in COCO JSON"
        if not isinstance(data[key], list):
            return False, f"Key '{key}' in COCO JSON must be a list"

    if not data["categories"]:
        return False, "No categories found in COCO JSON"

    if not data["images"]:
        return False, "No images found in COCO JSON"

    # Spot-check only the first 10 entries to keep validation fast on large
    # datasets. Malformed entries (not a dict / no string file_name) count
    # as not found instead of raising.
    found = 0
    for img_info in data["images"][:10]:
        file_name = img_info.get("file_name") if isinstance(img_info, dict) else None
        if isinstance(file_name, str) and (image_root / file_name).exists():
            found += 1

    if found == 0:
        return False, "None of the first 10 image files were found in the image directory"

    return True, f"Valid COCO dataset: {len(data['images'])} images, {len(data['categories'])} categories"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_num_classes(json_path):
    """Return the number of categories declared in a COCO JSON file.

    Args:
        json_path: Path (str or Path) to the COCO annotation file.

    Returns:
        ``len(data["categories"])`` of the parsed file.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError / json.JSONDecodeError: if it is not UTF-8 JSON.
        KeyError: if the top-level "categories" key is missing.
    """
    # Explicit UTF-8 so results do not depend on the platform locale encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return len(data["categories"])
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Inference wrapper around Detectron2's DefaultPredictor."""
|
|
2
|
+
|
|
3
|
+
import cv2
|
|
4
|
+
from detectron2.config import get_cfg
|
|
5
|
+
from detectron2.data import MetadataCatalog
|
|
6
|
+
from detectron2.engine import DefaultPredictor
|
|
7
|
+
from detectron2.utils.visualizer import ColorMode, Visualizer
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def load_predictor(config_yaml_path, weights_path, threshold=0.5):
    """Load a Detectron2 predictor from a saved config + weights pair.

    Args:
        config_yaml_path: Path (str or Path) to the saved config.yaml file.
        weights_path: Path (str or Path) to the .pth weights file.
        threshold: Confidence threshold for predictions. It is applied to
            every score-threshold section the config defines (ROI heads and
            RetinaNet).

    Returns:
        (DefaultPredictor, CfgNode). The CfgNode is frozen.
    """
    import torch

    cfg = get_cfg()
    # str() so pathlib.Path arguments work too.
    # NOTE(review): detectron2 resolves these through its path manager,
    # which is written around str paths.
    cfg.merge_from_file(str(config_yaml_path))
    cfg.MODEL.WEIGHTS = str(weights_path)

    if hasattr(cfg.MODEL, "ROI_HEADS"):
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
    if hasattr(cfg.MODEL, "RETINANET"):
        cfg.MODEL.RETINANET.SCORE_THRESH_TEST = threshold

    # A config saved on a GPU machine records MODEL.DEVICE="cuda". Fall back
    # to CPU so the model can still be loaded where CUDA is unavailable.
    if str(cfg.MODEL.DEVICE).startswith("cuda") and not torch.cuda.is_available():
        cfg.MODEL.DEVICE = "cpu"

    cfg.freeze()

    predictor = DefaultPredictor(cfg)
    return predictor, cfg
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def run_inference_single(predictor, image_path):
    """Predict on a single image file.

    Args:
        predictor: A DefaultPredictor instance (called with a BGR image array).
        image_path: Path (str or Path) to the image file.

    Returns:
        Tuple ``(image, instances)``. ``image`` is the array returned by
        ``cv2.imread`` (BGR). ``instances`` is the predictor output's
        "instances" entry moved to CPU.

    Raises:
        ValueError: if OpenCV cannot read the file.
    """
    image = cv2.imread(str(image_path))
    if image is not None:
        predictions = predictor(image)["instances"]
        return image, predictions.to("cpu")
    raise ValueError(f"Could not read image: {image_path}")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def visualize_predictions(image_bgr, instances, metadata=None):
    """Render prediction overlays with Detectron2's Visualizer.

    Args:
        image_bgr: Original image as a BGR numpy array (channels last).
        instances: Detectron2 Instances object, already on CPU.
        metadata: Optional catalog metadata (e.g. for class names). When None,
            the metadata registered under "__empty" is used.

    Returns:
        The annotated image as an RGB numpy array.
    """
    meta = MetadataCatalog.get("__empty") if metadata is None else metadata
    # Visualizer expects RGB: reverse the channel axis of the BGR input.
    rgb_view = image_bgr[:, :, ::-1]
    visualizer = Visualizer(rgb_view, metadata=meta, scale=1.0, instance_mode=ColorMode.IMAGE)
    return visualizer.draw_instance_predictions(instances).get_image()
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Available Detectron2 model zoo entries."""
|
|
2
|
+
|
|
3
|
+
# Maps user-friendly name -> {"config": model_zoo config path, "task": task type}.
# Insertion order is the order shown to users (see get_model_names).
MODEL_ZOO = {
    display_name: {"config": zoo_config, "task": task}
    for display_name, zoo_config, task in (
        # Object Detection
        ("Faster R-CNN (R50-FPN)", "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml", "detection"),
        ("Faster R-CNN (R101-FPN)", "COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml", "detection"),
        ("RetinaNet (R50-FPN)", "COCO-Detection/retinanet_R_50_FPN_3x.yaml", "detection"),
        ("RetinaNet (R101-FPN)", "COCO-Detection/retinanet_R_101_FPN_3x.yaml", "detection"),
        # Instance Segmentation
        ("Mask R-CNN (R50-FPN)", "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", "instance_segmentation"),
        ("Mask R-CNN (R101-FPN)", "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml", "instance_segmentation"),
    )
}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_model_names():
    """List every model display name, in MODEL_ZOO insertion order."""
    return [*MODEL_ZOO]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def get_detection_models():
    """List the display names of models whose task is "detection"."""
    names = []
    for name, entry in MODEL_ZOO.items():
        if entry["task"] == "detection":
            names.append(name)
    return names
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_segmentation_models():
    """List the display names of models whose task is "instance_segmentation"."""
    names = []
    for name, entry in MODEL_ZOO.items():
        if entry["task"] == "instance_segmentation":
            names.append(name)
    return names
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def get_config_path(model_name):
    """Look up the model-zoo config path for a display name.

    Raises:
        KeyError: if ``model_name`` is not a MODEL_ZOO key.
    """
    entry = MODEL_ZOO[model_name]
    return entry["config"]
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_task(model_name):
    """Look up a model's task type ('detection' or 'instance_segmentation').

    Raises:
        KeyError: if ``model_name`` is not a MODEL_ZOO key.
    """
    entry = MODEL_ZOO[model_name]
    return entry["task"]
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Custom Detectron2 trainer with metric emission for the GUI."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
from detectron2.engine import DefaultTrainer, HookBase
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MetricEmitterHook(HookBase):
    """Hook that pushes training metrics to a queue and honours stop requests.

    After every step it checks ``stop_event``. If the event is set, it posts a
    log message and calls ``sys.exit(0)``. Otherwise, every ``period``
    iterations and on the final iteration, it posts a dict of the latest
    scalar values from the trainer's event storage, plus "iter" and
    ``"type": "metrics"``.
    """

    def __init__(self, queue, stop_event, period=20):
        """
        Args:
            queue: multiprocessing.Queue (or anything with ``put``) to push
                metric dicts into.
            stop_event: multiprocessing.Event (or anything with ``is_set``);
                when set, training stops.
            period: Emit metrics every N iterations; must be positive.

        Raises:
            ValueError: if ``period`` is not positive.
        """
        if period <= 0:
            raise ValueError(f"period must be positive, got {period}")
        self._queue = queue
        self._stop_event = stop_event
        self._period = period

    def after_step(self):
        # Check for stop request. sys.exit raises SystemExit, which aborts
        # the training loop from inside the hook.
        if self._stop_event.is_set():
            self._queue.put({"type": "log", "msg": "Training stopped by user."})
            sys.exit(0)

        completed = self.trainer.iter + 1  # 1-based count of finished iterations
        max_iter = getattr(self.trainer, "max_iter", None)
        # Also emit on the last iteration, so the final losses reach the GUI
        # even when max_iter is not a multiple of the period.
        is_last = max_iter is not None and completed >= max_iter
        if completed % self._period != 0 and not is_last:
            return

        # storage.latest() values are (value, iteration) tuples; keep the value.
        metrics = {
            name: entry[0]
            for name, entry in self.trainer.storage.latest().items()
            if isinstance(entry, tuple)
        }
        metrics["iter"] = completed
        metrics["type"] = "metrics"
        self._queue.put(metrics)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class QueueLogHandler(logging.Handler):
    """Logging handler that forwards formatted records to a queue.

    Each record is sent as ``{"type": "log", "msg": <formatted message>}``.
    """

    def __init__(self, queue):
        """
        Args:
            queue: multiprocessing.Queue (or anything with ``put``) that
                receives the log dicts.
        """
        super().__init__()
        self._queue = queue

    def emit(self, record):
        try:
            msg = self.format(record)
            self._queue.put({"type": "log", "msg": msg})
        except RecursionError:
            raise
        except Exception:
            # Never let a logging failure break the code that logged. Report it
            # the standard way: handleError prints a traceback to stderr only
            # when logging.raiseExceptions is true, and never raises.
            self.handleError(record)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class AlchemyTrainer(DefaultTrainer):
    """Detectron2 trainer that emits metrics to a queue for GUI consumption.

    Identical to DefaultTrainer except that build_hooks() appends a
    MetricEmitterHook (with its default period) after the standard hooks.
    """

    def __init__(self, cfg, metric_queue, stop_event):
        """
        Args:
            cfg: Detectron2 CfgNode used to build the trainer.
            metric_queue: Queue that receives the metric/log dicts posted by
                MetricEmitterHook.
            stop_event: Event that, once set, makes MetricEmitterHook stop training.
        """
        # Assigned BEFORE super().__init__(): build_hooks() reads these, and
        # DefaultTrainer's constructor calls build_hooks() while it runs.
        # NOTE(review): true for current detectron2; re-check if the pinned
        # version changes.
        self._metric_queue = metric_queue
        self._stop_event = stop_event
        super().__init__(cfg)

    def build_hooks(self):
        """Return DefaultTrainer's hooks with a MetricEmitterHook appended last."""
        hooks = super().build_hooks()
        hooks.append(MetricEmitterHook(self._metric_queue, self._stop_event))
        return hooks
|
|
File without changes
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""File dialogs for model save/load and dataset selection."""
|
|
2
|
+
|
|
3
|
+
import shutil
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from PyQt6.QtWidgets import QFileDialog, QMessageBox
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def browse_directory(parent, title="Select Directory", start_dir=""):
    """Ask the user to choose a directory.

    Args:
        parent: Parent widget for the dialog.
        title: Dialog window title.
        start_dir: Directory the dialog opens in ("" lets Qt choose).

    Returns:
        The chosen directory path, or an empty string if cancelled.
    """
    return QFileDialog.getExistingDirectory(parent, title, start_dir)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def browse_file(parent, title="Select File", start_dir="", filter_str="All Files (*)"):
    """Ask the user to choose one existing file.

    Args:
        parent: Parent widget for the dialog.
        title: Dialog window title.
        start_dir: Directory the dialog opens in ("" lets Qt choose).
        filter_str: Qt name-filter string, e.g. "Images (*.png *.jpg)".

    Returns:
        The chosen file path, or an empty string if cancelled.
    """
    selected_path, _selected_filter = QFileDialog.getOpenFileName(
        parent, title, start_dir, filter_str
    )
    return selected_path
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def save_model_dialog(parent, output_dir):
    """Save trained model (.pth + config.yaml) to a user-chosen directory.

    Copies ``model_final.pth`` and, if present, ``config.yaml`` from the
    training output directory. If the user picks the output directory itself,
    nothing is copied (the files are already there) and it is reported as saved.

    Args:
        parent: Parent widget.
        output_dir: The training output directory containing model_final.pth
            and, optionally, config.yaml.

    Returns:
        Path string of the destination directory. Returns None if there is
        no model_final.pth, the user cancelled, or copying failed; in the
        first and last cases a message box tells the user why.
    """
    output_path = Path(output_dir)
    weights_file = output_path / "model_final.pth"
    config_file = output_path / "config.yaml"

    if not weights_file.exists():
        QMessageBox.warning(parent, "Save Model", "No model_final.pth found. Train a model first.")
        return None

    dest_dir = QFileDialog.getExistingDirectory(parent, "Save Model To Directory")
    if not dest_dir:
        return None

    dest = Path(dest_dir)
    try:
        # Copying a file onto itself raises shutil.SameFileError. When the user
        # chose the output directory itself, the model is already in place.
        if dest.resolve() != output_path.resolve():
            shutil.copy2(weights_file, dest / "model_final.pth")
            if config_file.exists():
                shutil.copy2(config_file, dest / "config.yaml")
        QMessageBox.information(parent, "Save Model", f"Model saved to:\n{dest}")
        return str(dest)
    except Exception as e:
        # GUI action boundary: report any failure to the user instead of crashing.
        QMessageBox.critical(parent, "Save Model Error", str(e))
        return None
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def load_model_dialog(parent):
    """Ask the user for model weights and the matching config file.

    If a ``config.yaml`` sits next to the chosen weights file, it is used
    without asking. Otherwise a second dialog, opened in the weights' folder,
    asks for the config.

    Returns:
        ``(config_yaml_path, weights_path)``, or ``(None, None)`` if either
        dialog is cancelled.
    """
    weights_path = QFileDialog.getOpenFileName(
        parent, "Select Model Weights", "", "PyTorch Weights (*.pth);;All Files (*)"
    )[0]
    if not weights_path:
        return None, None

    folder = Path(weights_path).parent
    sibling_config = folder / "config.yaml"
    if sibling_config.exists():
        return str(sibling_config), weights_path

    config_path = QFileDialog.getOpenFileName(
        parent, "Select Config YAML", str(folder), "YAML Files (*.yaml *.yml);;All Files (*)"
    )[0]
    return (config_path, weights_path) if config_path else (None, None)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Image viewer widget that displays images with detection overlays."""
|
|
2
|
+
|
|
3
|
+
from PyQt6.QtCore import Qt
|
|
4
|
+
from PyQt6.QtGui import QImage, QPixmap
|
|
5
|
+
from PyQt6.QtWidgets import QLabel, QScrollArea, QVBoxLayout, QWidget
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ImageViewer(QWidget):
    """Scrollable widget showing an RGB image scaled to fit, keeping aspect ratio."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self._label = QLabel()
        self._label.setAlignment(Qt.AlignmentFlag.AlignCenter)

        scroll = QScrollArea()
        scroll.setWidget(self._label)
        scroll.setWidgetResizable(True)

        layout = QVBoxLayout(self)
        layout.setContentsMargins(0, 0, 0, 0)
        layout.addWidget(scroll)

        # Unscaled source pixmap. Every display update rescales from this, so
        # repeated resizes do not compound scaling loss.
        self._current_pixmap = None

    def set_image_rgb(self, image_rgb):
        """Display an RGB numpy array image.

        Args:
            image_rgb: numpy array of shape (H, W, 3) with dtype uint8, or
                None to clear the view. Non-C-contiguous arrays (e.g.
                channel-reversed views such as ``bgr[:, :, ::-1]``) are
                copied into contiguous memory first.
        """
        if image_rgb is None:
            self.clear_image()
            return

        if not image_rgb.flags["C_CONTIGUOUS"]:
            # QImage reads rows of exactly bytes_per_line bytes straight from
            # the buffer, so a strided view would be rendered from wrong memory.
            image_rgb = image_rgb.copy(order="C")

        h, w, ch = image_rgb.shape
        bytes_per_line = ch * w
        qimg = QImage(image_rgb.data, w, h, bytes_per_line, QImage.Format.Format_RGB888)
        self._current_pixmap = QPixmap.fromImage(qimg)
        self._update_display()

    def _update_display(self):
        """Scale pixmap to fit the label while keeping aspect ratio."""
        if self._current_pixmap is None:
            return
        scaled = self._current_pixmap.scaled(
            self._label.size(),
            Qt.AspectRatioMode.KeepAspectRatio,
            Qt.TransformationMode.SmoothTransformation,
        )
        self._label.setPixmap(scaled)

    def resizeEvent(self, event):
        super().resizeEvent(event)
        # Re-fit the image to the new widget size.
        self._update_display()

    def clear_image(self):
        """Clear the displayed image."""
        self._label.clear()
        self._current_pixmap = None
|