alchemydetect 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
alchemydetect/app.py ADDED
@@ -0,0 +1,23 @@
1
+ """Entry point for the AlchemyDetect application."""
2
+
3
+ import multiprocessing
4
+ import sys
5
+
6
+ from PyQt6.QtWidgets import QApplication
7
+
8
+ from alchemydetect.gui.main_window import MainWindow
9
+
10
+
11
def main():
    """Start the AlchemyDetect GUI.

    Builds the QApplication, shows the main window and blocks in the Qt
    event loop; the process then exits with the loop's return code.
    """
    qt_app = QApplication(sys.argv)
    qt_app.setApplicationName("AlchemyDetect")

    main_window = MainWindow()
    main_window.show()

    exit_code = qt_app.exec()
    sys.exit(exit_code)
19
+
20
+
21
if __name__ == "__main__":
    # Must be the first thing the entry module does. Per the multiprocessing
    # docs, freeze_support() only has an effect in frozen Windows
    # executables, where it lets child processes run their task instead of
    # re-running this script (and opening another GUI).
    multiprocessing.freeze_support()
    main()
File without changes
@@ -0,0 +1,90 @@
1
+ """Build Detectron2 config from user-specified parameters."""
2
+
3
+ import torch
4
+ from detectron2 import model_zoo
5
+ from detectron2.config import get_cfg
6
+
7
+ from .dataset_utils import get_num_classes, register_coco_dataset
8
+ from .model_catalog import get_config_path
9
+
10
+
11
def build_cfg(
    model_name,
    train_images_dir,
    train_json,
    output_dir,
    lr=0.0025,
    max_iter=1000,
    batch_size=2,
    val_images_dir=None,
    val_json=None,
    resume=False,
    weights_path=None,
):
    """Assemble a frozen Detectron2 training config from GUI selections.

    Side effect: registers the training set as "alchemy_train" (and, when
    both validation arguments are given, the validation set as
    "alchemy_val") in Detectron2's dataset catalog.

    Args:
        model_name: Display name from the model catalog,
            e.g. "Faster R-CNN (R50-FPN)".
        train_images_dir: Directory containing the training images.
        train_json: COCO-format annotation file for the training set; its
            category count sets NUM_CLASSES.
        output_dir: Directory for checkpoints and logs (cfg.OUTPUT_DIR).
        lr: Base learning rate (SOLVER.BASE_LR).
        max_iter: Total training iterations (SOLVER.MAX_ITER).
        batch_size: Images per batch (SOLVER.IMS_PER_BATCH).
        val_images_dir: Optional validation image directory.
        val_json: Optional validation COCO annotation file.
        resume: Not read by this function. NOTE(review): presumably the
            caller passes it to the trainer's resume/load step - confirm.
        weights_path: Optional .pth to initialise from; when falsy the
            model zoo's pretrained checkpoint URL is used instead.

    Returns:
        A frozen CfgNode ready for training.
    """
    zoo_config = get_config_path(model_name)
    class_count = get_num_classes(train_json)

    # Datasets are (re-)registered under fixed names on every call.
    register_coco_dataset("alchemy_train", train_json, train_images_dir)
    has_val = bool(val_json and val_images_dir)
    if has_val:
        register_coco_dataset("alchemy_val", val_json, val_images_dir)

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(zoo_config))

    cfg.DATASETS.TRAIN = ("alchemy_train",)
    cfg.DATASETS.TEST = ("alchemy_val",) if has_val else ()
    cfg.DATALOADER.NUM_WORKERS = 2

    # Custom weights win; otherwise start from the zoo's pretrained model.
    cfg.MODEL.WEIGHTS = weights_path or model_zoo.get_checkpoint_url(zoo_config)

    solver = cfg.SOLVER
    solver.IMS_PER_BATCH = batch_size
    solver.BASE_LR = lr
    solver.MAX_ITER = max_iter
    solver.STEPS = []  # No LR decay for simplicity
    # Save roughly five checkpoints per run, but never more often than every 100 iters.
    solver.CHECKPOINT_PERIOD = max(max_iter // 5, 100)

    # Set the class count on whichever head families the config defines.
    for head in ("ROI_HEADS", "RETINANET"):
        if hasattr(cfg.MODEL, head):
            getattr(cfg.MODEL, head).NUM_CLASSES = class_count

    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    cfg.OUTPUT_DIR = output_dir

    cfg.freeze()
    return cfg
@@ -0,0 +1,62 @@
1
+ """COCO dataset registration helpers for Detectron2."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+
6
+ from detectron2.data import DatasetCatalog, MetadataCatalog
7
+ from detectron2.data.datasets import register_coco_instances
8
+
9
+
10
def register_coco_dataset(name, json_path, image_root):
    """Register a COCO-format dataset under ``name``, replacing any previous entry.

    Existing dataset and metadata entries for ``name`` are removed first, so
    the fixed names used by the GUI can be re-pointed at new files on every
    training run instead of being skipped.

    Args:
        name: Catalog name, e.g. "alchemy_train".
        json_path: Path to the COCO JSON annotation file.
        image_root: Directory the JSON's image ``file_name`` entries are
            relative to.
    """
    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)
    # Checked independently of the dataset catalog: a metadata entry can
    # exist on its own (e.g. after MetadataCatalog.get(name)), and a stale
    # entry holding a different json_file/image_root would conflict with the
    # values register_coco_instances sets.
    if name in MetadataCatalog.list():
        MetadataCatalog.remove(name)
    register_coco_instances(name, {}, json_path, image_root)
16
+
17
+
18
def validate_coco_json(json_path, image_root):
    """Sanity-check a COCO annotation file and its image directory.

    Checks, in order: the JSON file exists; ``image_root`` is a directory;
    the file is readable UTF-8 JSON whose top level is an object with
    ``images``, ``annotations`` and ``categories`` keys; ``images`` and
    ``categories`` are non-empty lists; and at least one of the first 10
    listed image files exists under ``image_root``.

    Args:
        json_path: Path (str or Path) to the COCO JSON file.
        image_root: Directory that image ``file_name`` entries are relative to.

    Returns:
        ``(is_valid, message)``: ``message`` describes the first problem
        found, or summarises the dataset when valid. Bad input is reported
        through the return value rather than raised.
    """
    json_path = Path(json_path)
    image_root = Path(image_root)

    if not json_path.exists():
        return False, f"Annotation file not found: {json_path}"

    if not image_root.is_dir():
        return False, f"Image directory not found: {image_root}"

    try:
        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        return False, f"Invalid JSON: {e}"
    except (OSError, UnicodeDecodeError) as e:
        # e.g. the path is a directory, is unreadable, or is not UTF-8.
        return False, f"Could not read annotation file: {e}"

    if not isinstance(data, dict):
        return False, "COCO JSON must be an object at the top level"

    for key in ("images", "annotations", "categories"):
        if key not in data:
            return False, f"Missing required key '{key}' in COCO JSON"

    images = data["images"]
    categories = data["categories"]
    if not isinstance(images, list) or not isinstance(categories, list):
        return False, "'images' and 'categories' in COCO JSON must be lists"

    if not categories:
        return False, "No categories found in COCO JSON"

    if not images:
        return False, "No images found in COCO JSON"

    # Spot-check only the first few entries to keep validation fast. Entries
    # without a usable file_name are skipped; an empty name would otherwise
    # resolve to image_root itself.
    found = sum(
        1
        for img_info in images[:10]
        if isinstance(img_info, dict)
        and isinstance(img_info.get("file_name"), str)
        and img_info["file_name"]
        and (image_root / img_info["file_name"]).is_file()
    )

    if found == 0:
        return False, "None of the first 10 image files were found in the image directory"

    return True, f"Valid COCO dataset: {len(images)} images, {len(categories)} categories"
56
+
57
+
58
def get_num_classes(json_path):
    """Count the categories declared in a COCO annotation file.

    Args:
        json_path: Path to the COCO JSON file.

    Returns:
        The length of the file's ``categories`` list.
    """
    with open(json_path, "r") as handle:
        categories = json.load(handle)["categories"]
    return len(categories)
@@ -0,0 +1,67 @@
1
+ """Inference wrapper around Detectron2's DefaultPredictor."""
2
+
3
+ import cv2
4
+ from detectron2.config import get_cfg
5
+ from detectron2.data import MetadataCatalog
6
+ from detectron2.engine import DefaultPredictor
7
+ from detectron2.utils.visualizer import ColorMode, Visualizer
8
+
9
+
10
def load_predictor(config_yaml_path, weights_path, threshold=0.5):
    """Load a Detectron2 predictor from a saved config + weights pair.

    Args:
        config_yaml_path: Path to the saved config.yaml file.
        weights_path: Path to the .pth weights file.
        threshold: Confidence threshold for predictions; applied to both the
            ROI heads and RetinaNet score thresholds.

    Returns:
        (DefaultPredictor, CfgNode) - the predictor and the frozen config it
        was built from.
    """
    # Local import: torch is only needed here to probe CUDA availability.
    import torch

    cfg = get_cfg()
    cfg.merge_from_file(config_yaml_path)
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
    if hasattr(cfg.MODEL, "RETINANET"):
        cfg.MODEL.RETINANET.SCORE_THRESH_TEST = threshold
    # A config produced on a GPU machine stores MODEL.DEVICE="cuda" (build_cfg
    # picks "cuda" whenever it is available). Fall back to the CPU so such a
    # model can still be loaded on a machine without CUDA.
    if str(cfg.MODEL.DEVICE).startswith("cuda") and not torch.cuda.is_available():
        cfg.MODEL.DEVICE = "cpu"
    cfg.freeze()

    predictor = DefaultPredictor(cfg)
    return predictor, cfg
31
+
32
+
33
def run_inference_single(predictor, image_path):
    """Read one image from disk and run the predictor on it.

    Args:
        predictor: A DefaultPredictor; called with the BGR image.
        image_path: Location of the image (str or path-like).

    Returns:
        ``(image_bgr, instances)``: the image exactly as read by OpenCV and
        the predictor's ``"instances"`` output moved to the CPU.

    Raises:
        ValueError: If OpenCV cannot read the file.
    """
    image_bgr = cv2.imread(str(image_path))
    if image_bgr is not None:
        predictions = predictor(image_bgr)["instances"]
        return image_bgr, predictions.to("cpu")
    raise ValueError(f"Could not read image: {image_path}")
48
+
49
+
50
def visualize_predictions(image_bgr, instances, metadata=None):
    """Render predicted instances on top of an image.

    Args:
        image_bgr: Three-axis image array with channels in BGR order.
        instances: Detectron2 Instances on the CPU.
        metadata: Optional catalog metadata (e.g. class names); the
            ``"__empty"`` catalog entry is used when omitted.

    Returns:
        The annotated image as an RGB numpy array.
    """
    meta = MetadataCatalog.get("__empty") if metadata is None else metadata
    rgb_view = image_bgr[:, :, ::-1]  # reverse the channel axis: BGR -> RGB
    visualizer = Visualizer(rgb_view, metadata=meta, scale=1.0, instance_mode=ColorMode.IMAGE)
    return visualizer.draw_instance_predictions(instances).get_image()
@@ -0,0 +1,56 @@
1
+ """Available Detectron2 model zoo entries."""
2
+
3
# Maps user-friendly display name -> dict with:
#   "config": Detectron2 model zoo config path (e.g. "COCO-Detection/...yaml")
#   "task":   "detection" (boxes only) or "instance_segmentation" (boxes + masks)
MODEL_ZOO = {
    # Object Detection
    "Faster R-CNN (R50-FPN)": {
        "config": "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml",
        "task": "detection",
    },
    "Faster R-CNN (R101-FPN)": {
        "config": "COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml",
        "task": "detection",
    },
    "RetinaNet (R50-FPN)": {
        "config": "COCO-Detection/retinanet_R_50_FPN_3x.yaml",
        "task": "detection",
    },
    "RetinaNet (R101-FPN)": {
        "config": "COCO-Detection/retinanet_R_101_FPN_3x.yaml",
        "task": "detection",
    },
    # Instance Segmentation
    "Mask R-CNN (R50-FPN)": {
        "config": "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml",
        "task": "instance_segmentation",
    },
    "Mask R-CNN (R101-FPN)": {
        "config": "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml",
        "task": "instance_segmentation",
    },
}
32
+
33
+
34
def get_model_names():
    """List every display name in MODEL_ZOO, in definition order."""
    return [name for name in MODEL_ZOO]
37
+
38
+
39
def get_detection_models():
    """List display names whose task is "detection", in definition order."""
    return [name for name in MODEL_ZOO if MODEL_ZOO[name]["task"] == "detection"]
42
+
43
+
44
def get_segmentation_models():
    """List display names whose task is "instance_segmentation", in definition order."""
    return [name for name in MODEL_ZOO if MODEL_ZOO[name]["task"] == "instance_segmentation"]
47
+
48
+
49
def get_config_path(model_name):
    """Look up the Detectron2 model zoo config path for a display name.

    Raises:
        KeyError: If ``model_name`` is not a key of MODEL_ZOO.
    """
    entry = MODEL_ZOO[model_name]
    return entry["config"]
52
+
53
+
54
def get_task(model_name):
    """Look up a model's task: 'detection' or 'instance_segmentation'.

    Raises:
        KeyError: If ``model_name`` is not a key of MODEL_ZOO.
    """
    entry = MODEL_ZOO[model_name]
    return entry["task"]
@@ -0,0 +1,67 @@
1
+ """Custom Detectron2 trainer with metric emission for the GUI."""
2
+
3
+ import logging
4
+ import sys
5
+
6
+ from detectron2.engine import DefaultTrainer, HookBase
7
+
8
+
9
class MetricEmitterHook(HookBase):
    """Hook that pushes training metrics to a queue and honours stop requests.

    After every step it checks ``stop_event``. Every ``period`` iterations,
    and always on the final iteration, it sends the latest scalar values
    from the trainer's event storage as a dict tagged ``"type": "metrics"``
    with the 1-based iteration number under ``"iter"``.
    """

    def __init__(self, queue, stop_event, period=20):
        """
        Args:
            queue: multiprocessing.Queue to push metric dicts into.
            stop_event: multiprocessing.Event; when set, training stops.
            period: Emit metrics every N iterations.
        """
        self._queue = queue
        self._stop_event = stop_event
        self._period = period

    def after_step(self):
        # Stop requested by the GUI: report it, then unwind out of training.
        # sys.exit raises SystemExit, which propagates out of the training loop.
        if self._stop_event.is_set():
            self._queue.put({"type": "log", "msg": "Training stopped by user."})
            sys.exit(0)

        completed = self.trainer.iter + 1
        # Also emit on the last iteration so the GUI sees training reach
        # max_iter even when max_iter is not a multiple of the period.
        is_last = completed == self.trainer.max_iter
        if completed % self._period != 0 and not is_last:
            return

        # storage.latest() maps name -> (value, iteration); keep the values.
        metrics = {
            name: entry[0]
            for name, entry in self.trainer.storage.latest().items()
            if isinstance(entry, tuple)
        }
        metrics["iter"] = completed
        metrics["type"] = "metrics"
        self._queue.put(metrics)
39
+
40
+
41
class QueueLogHandler(logging.Handler):
    """Logging handler that forwards formatted records to a queue.

    Each record is sent as ``{"type": "log", "msg": <formatted text>}``.
    """

    def __init__(self, queue):
        """
        Args:
            queue: Object with a ``put`` method, e.g. a multiprocessing.Queue.
        """
        super().__init__()
        self._queue = queue

    def emit(self, record):
        try:
            msg = self.format(record)
            self._queue.put({"type": "log", "msg": msg})
        except RecursionError:
            raise
        except Exception:
            # Standard logging contract: a failed emit must not propagate into
            # the code that logged. handleError reports the failure instead of
            # silently dropping it (it respects logging.raiseExceptions).
            self.handleError(record)
54
+
55
+
56
class AlchemyTrainer(DefaultTrainer):
    """DefaultTrainer variant that reports progress to the GUI process.

    Appends a MetricEmitterHook, so metrics are pushed to ``metric_queue``
    and setting ``stop_event`` ends training early.
    """

    def __init__(self, cfg, metric_queue, stop_event):
        # Keep these assignments before super().__init__(): build_hooks()
        # reads them, and DefaultTrainer's constructor is expected to call
        # build_hooks() while it runs.
        self._metric_queue, self._stop_event = metric_queue, stop_event
        super().__init__(cfg)

    def build_hooks(self):
        """Return Detectron2's default hooks with the GUI metric hook appended last."""
        default_hooks = super().build_hooks()
        return default_hooks + [MetricEmitterHook(self._metric_queue, self._stop_event)]
File without changes
@@ -0,0 +1,81 @@
1
+ """File dialogs for model save/load and dataset selection."""
2
+
3
+ import shutil
4
+ from pathlib import Path
5
+
6
+ from PyQt6.QtWidgets import QFileDialog, QMessageBox
7
+
8
+
9
def browse_directory(parent, title="Select Directory", start_dir=""):
    """Ask the user to choose a directory.

    Returns:
        The chosen directory path, or "" if the dialog was cancelled.
    """
    return QFileDialog.getExistingDirectory(parent, title, start_dir)
13
+
14
+
15
def browse_file(parent, title="Select File", start_dir="", filter_str="All Files (*)"):
    """Ask the user to choose an existing file.

    Args:
        filter_str: Qt name-filter string, e.g. "Images (*.png *.jpg)".

    Returns:
        The chosen file path, or "" if cancelled. The selected filter is discarded.
    """
    selection = QFileDialog.getOpenFileName(parent, title, start_dir, filter_str)
    return selection[0]
19
+
20
+
21
def save_model_dialog(parent, output_dir):
    """Save trained model (.pth + config.yaml) to a user-chosen directory.

    Copies ``model_final.pth`` and, if present, ``config.yaml`` from the
    training output directory. Files already at the destination are only
    overwritten after the user confirms. Choosing the output directory
    itself is treated as a successful save (nothing needs copying).

    Args:
        parent: Parent widget for the dialogs.
        output_dir: The training output directory containing model_final.pth and config.yaml.

    Returns:
        Path to the saved directory as a string, or None if there was no
        model, the user cancelled, or copying failed.
    """
    output_path = Path(output_dir)
    weights_file = output_path / "model_final.pth"
    config_file = output_path / "config.yaml"

    if not weights_file.exists():
        QMessageBox.warning(parent, "Save Model", "No model_final.pth found. Train a model first.")
        return None

    dest_dir = QFileDialog.getExistingDirectory(parent, "Save Model To Directory")
    if not dest_dir:
        return None

    dest = Path(dest_dir)
    copies = [(weights_file, dest / "model_final.pth")]
    if config_file.exists():
        copies.append((config_file, dest / "config.yaml"))

    try:
        # Drop pairs whose destination already is the source file (the user
        # picked the output directory); shutil.copy2 would raise SameFileError.
        copies = [(src, dst) for src, dst in copies if not (dst.exists() and dst.samefile(src))]

        existing = [dst.name for _, dst in copies if dst.exists()]
        if existing:
            answer = QMessageBox.question(
                parent,
                "Save Model",
                f"The following file(s) already exist in:\n{dest}\n\n"
                f"{', '.join(existing)}\n\nOverwrite them?",
                QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
                QMessageBox.StandardButton.No,
            )
            if answer != QMessageBox.StandardButton.Yes:
                return None

        for src, dst in copies:
            shutil.copy2(src, dst)
    except Exception as e:
        # UI boundary: report any copy failure to the user instead of crashing.
        QMessageBox.critical(parent, "Save Model Error", str(e))
        return None

    QMessageBox.information(parent, "Save Model", f"Model saved to:\n{dest}")
    return str(dest)
53
+
54
+
55
def load_model_dialog(parent):
    """Let the user choose a weights file and locate its config.

    A ``config.yaml`` next to the chosen weights is used automatically.
    Otherwise the user is asked for a YAML file, starting in the weights'
    directory.

    Returns:
        (config_yaml_path, weights_path) or (None, None) if either dialog is cancelled.
    """
    weights_path = QFileDialog.getOpenFileName(
        parent, "Select Model Weights", "", "PyTorch Weights (*.pth);;All Files (*)"
    )[0]
    if not weights_path:
        return None, None

    weights_dir = Path(weights_path).parent
    sibling_config = weights_dir / "config.yaml"
    if sibling_config.exists():
        config_path = str(sibling_config)
    else:
        config_path = QFileDialog.getOpenFileName(
            parent, "Select Config YAML", str(weights_dir), "YAML Files (*.yaml *.yml);;All Files (*)"
        )[0]
        if not config_path:
            return None, None

    return config_path, weights_path
@@ -0,0 +1,61 @@
1
+ """Image viewer widget that displays images with detection overlays."""
2
+
3
+ from PyQt6.QtCore import Qt
4
+ from PyQt6.QtGui import QImage, QPixmap
5
+ from PyQt6.QtWidgets import QLabel, QScrollArea, QVBoxLayout, QWidget
6
+
7
+
8
class ImageViewer(QWidget):
    """Scrollable widget that shows an RGB image scaled to fit, keeping aspect ratio."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self._label = QLabel()
        self._label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        # An explicit minimum size overrides QLabel's minimum size hint, which
        # otherwise tracks the current pixmap's size. Without it, once the image
        # is scaled up the resizable scroll area can never shrink the label
        # (or the image) again.
        self._label.setMinimumSize(1, 1)

        scroll = QScrollArea()
        scroll.setWidget(self._label)
        scroll.setWidgetResizable(True)

        layout = QVBoxLayout(self)
        layout.setContentsMargins(0, 0, 0, 0)
        layout.addWidget(scroll)

        self._current_pixmap = None

    def set_image_rgb(self, image_rgb):
        """Display an RGB numpy array image.

        Args:
            image_rgb: numpy array of shape (H, W, 3) with dtype uint8, or
                None to clear the display.
        """
        if image_rgb is None:
            self.clear_image()
            return

        # QImage reads rows straight out of the raw buffer, so the array must be
        # C-contiguous; views such as img[:, :, ::-1] are not.
        if not image_rgb.flags["C_CONTIGUOUS"]:
            image_rgb = image_rgb.copy()

        h, w, ch = image_rgb.shape
        bytes_per_line = image_rgb.strides[0]
        # .copy() detaches the QImage from the numpy buffer, so the pixmap never
        # refers to memory owned by image_rgb.
        qimg = QImage(image_rgb.data, w, h, bytes_per_line, QImage.Format.Format_RGB888).copy()
        self._current_pixmap = QPixmap.fromImage(qimg)
        self._update_display()

    def _update_display(self):
        """Scale pixmap to fit the label while keeping aspect ratio."""
        if self._current_pixmap is None:
            return
        scaled = self._current_pixmap.scaled(
            self._label.size(),
            Qt.AspectRatioMode.KeepAspectRatio,
            Qt.TransformationMode.SmoothTransformation,
        )
        self._label.setPixmap(scaled)

    def resizeEvent(self, event):
        # Re-fit the image whenever the viewer changes size.
        super().resizeEvent(event)
        self._update_display()

    def clear_image(self):
        """Clear the displayed image."""
        self._label.clear()
        self._current_pixmap = None