bplusplus 2.0.1.tar.gz → 2.0.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: bplusplus
-Version: 2.0.1
+Version: 2.0.4
 Summary: A simple method to create AI models for biodiversity, with collect and prepare pipeline
 License: MIT
 Author: Titus Venverloo
@@ -128,10 +128,13 @@ bplusplus.prepare(
     output_directory=PREPARED_DATA_DIR,
     img_size=640, # Target image size for training
     conf=0.6, # Detection confidence threshold (0-1)
-    valid=0.1, # Validation split ratio (0-1), set to 0 for no validation
+    valid=0.1, # Validation split ratio (0-1), set to 0 for no validation
+    blur=None, # Gaussian blur as fraction of image size (0-1), None = disabled
 )
 ```
 
+**Note:** The `blur` parameter applies Gaussian blur before resizing, which can help reduce noise. Values are relative to image size (e.g., `blur=0.01` means 1% of the smallest dimension). Supported image formats: JPG, JPEG, and PNG.
+
 #### Step 3: Train Model
 Train the hierarchical classification model on your prepared data. The model learns to identify family, genus, and species.
 
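To make the new option concrete, here is a minimal sketch of a blur-enabled `prepare` call; the directory names are placeholders, and the other arguments mirror the README snippet in the hunk above.

```python
import bplusplus

# Sketch only: directory names are placeholders.
bplusplus.prepare(
    input_directory="raw_data",        # one subfolder of images per class
    output_directory="prepared_data",
    img_size=640,
    conf=0.6,
    valid=0.1,
    blur=0.01,  # Gaussian blur radius = 1% of each crop's smallest dimension
)
```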
@@ -191,11 +194,15 @@ results = bplusplus.inference(
     output_dir=OUTPUT_DIR,
     fps=None, # None = process all frames
     backbone="resnet50", # Must match training
+    save_video=True, # Set to False to skip video rendering (only CSV output)
+    img_size=60, # Must match training
 )
 
 print(f"Detected {results['tracks']} tracks ({results['confirmed_tracks']} confirmed)")
 ```
 
+**Note:** Set `save_video=False` to skip generating the annotated and debug videos, which speeds up processing when you only need the CSV detection data.
+
 **Custom Detection Configuration:**
 
 For advanced control over detection parameters, provide a YAML config file:
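A companion sketch of a CSV-only run using the new `save_video` flag. The `output_dir`, `fps`, and `backbone` keywords come from the hunk above; the video, model, and species keyword names are inferred from the code changes further down and should be treated as assumptions.

```python
import bplusplus

# Sketch only: paths and species are placeholders.
results = bplusplus.inference(
    video_path="pollinators.mp4",
    hierarchical_model_path="best_model.pth",
    species_list=["Apis mellifera", "Bombus terrestris"],
    output_dir="output",
    fps=None,
    backbone="resnet50",
    save_video=False,  # skip the annotated/debug MP4s; keep the CSV outputs
    img_size=60,       # must match the size used during training
)
print(f"Detected {results['tracks']} tracks")
```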
@@ -86,10 +86,13 @@ bplusplus.prepare(
     output_directory=PREPARED_DATA_DIR,
     img_size=640, # Target image size for training
     conf=0.6, # Detection confidence threshold (0-1)
-    valid=0.1, # Validation split ratio (0-1), set to 0 for no validation
+    valid=0.1, # Validation split ratio (0-1), set to 0 for no validation
+    blur=None, # Gaussian blur as fraction of image size (0-1), None = disabled
 )
 ```
 
+**Note:** The `blur` parameter applies Gaussian blur before resizing, which can help reduce noise. Values are relative to image size (e.g., `blur=0.01` means 1% of the smallest dimension). Supported image formats: JPG, JPEG, and PNG.
+
 #### Step 3: Train Model
 Train the hierarchical classification model on your prepared data. The model learns to identify family, genus, and species.
 
@@ -149,11 +152,15 @@ results = bplusplus.inference(
     output_dir=OUTPUT_DIR,
     fps=None, # None = process all frames
     backbone="resnet50", # Must match training
+    save_video=True, # Set to False to skip video rendering (only CSV output)
+    img_size=60, # Must match training
 )
 
 print(f"Detected {results['tracks']} tracks ({results['confirmed_tracks']} confirmed)")
 ```
 
+**Note:** Set `save_video=False` to skip generating the annotated and debug videos, which speeds up processing when you only need the CSV detection data.
+
 **Custom Detection Configuration:**
 
 For advanced control over detection parameters, provide a YAML config file:
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "bplusplus"
-version = "2.0.1"
+version = "2.0.4"
 description = "A simple method to create AI models for biodiversity, with collect and prepare pipeline"
 authors = ["Titus Venverloo <tvenver@mit.edu>", "Deniz Aydemir <deniz@aydemir.us>", "Orlando Closs <orlandocloss@pm.me>", "Ase Hatveit <aase@mit.edu>"]
 license = "MIT"
@@ -1,6 +1,6 @@
 """
-Insect Detection Backend Module
-===============================
+Detection Backend Module
+========================
 
 This module provides motion-based insect detection utilities used by the inference pipeline.
 It is NOT meant to be run directly - use inference.py instead.
@@ -18,7 +18,7 @@ import requests
 import logging
 
 from .tracker import InsectTracker
-from .insect_detector import (
+from .detector import (
     DEFAULT_DETECTION_CONFIG,
     get_default_config,
     build_detection_params,
@@ -210,7 +210,7 @@ class HierarchicalInsectClassifier(nn.Module):
     def forward(self, x):
         features = self.backbone(x)
         return [branch(features) for branch in self.branches]
-
+
 
 # ============================================================================
 # VISUALIZATION
@@ -257,8 +257,8 @@ class FrameVisualizer:
         lines = [track_display]
 
         for level, conf_key in [("family", "family_confidence"),
-                                ("genus", "genus_confidence"),
-                                ("species", "species_confidence")]:
+                                ("genus", "genus_confidence"),
+                                ("species", "species_confidence")]:
             if detection_data.get(level):
                 name = detection_data[level]
                 conf = detection_data.get(conf_key, 0)
@@ -307,7 +307,7 @@ class VideoInferenceProcessor:
     and track-based prediction aggregation.
     """
 
-    def __init__(self, species_list, hierarchical_model_path, params, backbone="resnet50"):
+    def __init__(self, species_list, hierarchical_model_path, params, backbone="resnet50", img_size=60):
         """
         Initialize the processor.
 
@@ -316,7 +316,9 @@ class VideoInferenceProcessor:
             hierarchical_model_path: Path to trained model weights
             params: Detection parameters dict
             backbone: ResNet backbone ('resnet18', 'resnet50', 'resnet101')
+            img_size: Image size for classification (should match training)
         """
+        self.img_size = img_size
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.species_list = species_list
         self.params = params
@@ -354,8 +356,7 @@ class VideoInferenceProcessor:
         self.model.eval()
 
         self.transform = transforms.Compose([
-            transforms.Resize((768, 768)),
-            transforms.CenterCrop(640),
+            transforms.Resize((self.img_size, self.img_size)),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         ])
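For reference, the replacement preprocessing in isolation: a single square resize to the classification size instead of the previous 768 px resize plus 640 px center crop. A sketch using the new default `img_size=60`:

```python
from torchvision import transforms

img_size = 60  # new default; should match the size used at training time
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),  # replaces Resize((768, 768)) + CenterCrop(640)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```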
@@ -463,13 +464,14 @@ class VideoInferenceProcessor:
 
         return fg_mask, frame_detections
 
-    def classify_confirmed_tracks(self, video_path, confirmed_track_ids):
+    def classify_confirmed_tracks(self, video_path, confirmed_track_ids, crops_dir=None):
         """
         Classify only the confirmed tracks by re-reading relevant frames.
 
         Args:
             video_path: Path to original video
             confirmed_track_ids: Set of track IDs that passed topology analysis
+            crops_dir: Optional directory to save cropped frames
 
         Returns:
             dict: track_id -> list of classifications
@@ -480,6 +482,15 @@ class VideoInferenceProcessor:
 
         print(f"\nClassifying {len(confirmed_track_ids)} confirmed tracks...")
 
+        # Setup crops directory if requested
+        if crops_dir:
+            os.makedirs(crops_dir, exist_ok=True)
+            # Create subdirectory for each track
+            for track_id in confirmed_track_ids:
+                track_dir = os.path.join(crops_dir, str(track_id)[:8])
+                os.makedirs(track_dir, exist_ok=True)
+            print(f"  Saving crops to: {crops_dir}")
+
         # Group detections by frame for confirmed tracks
         frames_to_classify = defaultdict(list)
         for det in self.all_detections:
@@ -518,11 +529,22 @@ class VideoInferenceProcessor:
                     track_classifications[det['track_id']].append(classification)
                     classified_count += 1
 
+                    # Save crop if requested
+                    if crops_dir:
+                        track_id = det['track_id']
+                        track_dir = os.path.join(crops_dir, str(track_id)[:8])
+                        crop = frame[int(y1):int(y2), int(x1):int(x2)]
+                        if crop.size > 0:
+                            crop_path = os.path.join(track_dir, f"frame_{target_frame:06d}.jpg")
+                            cv2.imwrite(crop_path, crop)
+
                     if classified_count % 20 == 0:
                         print(f"  Classified {classified_count} detections...", end='\r')
 
         cap.release()
         print(f"\n✓ Classified {classified_count} detections from {len(confirmed_track_ids)} tracks")
+        if crops_dir:
+            print(f"✓ Saved {classified_count} crops to {crops_dir}")
 
         return track_classifications
 
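The crop-export logic added above, condensed into a standalone sketch; the helper name and argument layout are ours, not the package's.

```python
import os
import cv2

def save_crop(frame, bbox, crops_dir, track_id, frame_idx):
    """Slice a detection's bounding box out of a frame and save it as a JPEG."""
    x1, y1, x2, y2 = bbox
    track_dir = os.path.join(crops_dir, str(track_id)[:8])  # short track-id folder
    os.makedirs(track_dir, exist_ok=True)
    crop = frame[int(y1):int(y2), int(x1):int(x2)]
    if crop.size > 0:  # guard against empty or degenerate boxes
        cv2.imwrite(os.path.join(track_dir, f"frame_{frame_idx:06d}.jpg"), crop)
```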
@@ -699,7 +721,7 @@ class VideoInferenceProcessor:
         print("\n" + "="*60)
         print("🐛 FINAL SUMMARY")
         print("="*60)
-
+
         if results:
             print(f"\n✓ CONFIRMED INSECTS ({num_confirmed}):")
             for r in results:
@@ -736,7 +758,7 @@ class VideoInferenceProcessor:
 # VIDEO PROCESSING
 # ============================================================================
 
-def process_video(video_path, processor, output_paths, show_video=False, fps=None):
+def process_video(video_path, processor, output_paths, show_video=False, fps=None, crops_dir=None):
     """
     Process video file with efficient classification (confirmed tracks only).
 
@@ -752,6 +774,7 @@ def process_video(video_path, processor, output_paths, show_video=False, fps=Non
         output_paths: Dict with output file paths
         show_video: Display video while processing
         fps: Target FPS (skip frames if lower than input)
+        crops_dir: Optional directory to save cropped frames for each track
 
     Returns:
         list: Aggregated results
@@ -832,7 +855,7 @@ def process_video(video_path, processor, output_paths, show_video=False, fps=Non
     print("="*60)
 
     if confirmed_track_ids:
-        processor.classify_confirmed_tracks(video_path, confirmed_track_ids)
+        processor.classify_confirmed_tracks(video_path, confirmed_track_ids, crops_dir=crops_dir)
         results = processor.hierarchical_aggregation(confirmed_track_ids)
     else:
         results = []
@@ -840,23 +863,29 @@ def process_video(video_path, processor, output_paths, show_video=False, fps=Non
     # ==========================================================================
     # PHASE 4: Render Videos
     # ==========================================================================
-    print("\n" + "="*60)
-    print("PHASE 4: RENDERING VIDEOS")
-    print("="*60)
-
-    # Render debug video (all detections, showing confirmed vs unconfirmed)
-    print(f"\nRendering debug video (all detections)...")
-    _render_debug_video(
-        video_path, output_paths["debug_video"],
-        processor, confirmed_track_ids, all_track_info, input_fps
-    )
-
-    # Render annotated video (confirmed tracks with classifications)
-    print(f"\nRendering annotated video ({len(confirmed_track_ids)} confirmed tracks)...")
-    _render_annotated_video(
-        video_path, output_paths["annotated_video"],
-        processor, confirmed_track_ids, input_fps
-    )
+    # Render videos if requested
+    if "annotated_video" in output_paths or "debug_video" in output_paths:
+        print("\n" + "="*60)
+        print("PHASE 4: RENDERING VIDEOS")
+        print("="*60)
+
+        # Render debug video (all detections, showing confirmed vs unconfirmed)
+        if "debug_video" in output_paths:
+            print(f"\nRendering debug video (all detections)...")
+            _render_debug_video(
+                video_path, output_paths["debug_video"],
+                processor, confirmed_track_ids, all_track_info, input_fps
+            )
+
+        # Render annotated video (confirmed tracks with classifications)
+        if "annotated_video" in output_paths:
+            print(f"\nRendering annotated video ({len(confirmed_track_ids)} confirmed tracks)...")
+            _render_annotated_video(
+                video_path, output_paths["annotated_video"],
+                processor, confirmed_track_ids, input_fps
+            )
+    else:
+        print("\n(Video rendering skipped)")
 
     # Save results
     processor.save_results(results, output_paths)
@@ -1050,6 +1079,9 @@ def inference(
     fps=None,
     config=None,
     backbone="resnet50",
+    crops=False,
+    save_video=True,
+    img_size=60,
 ):
     """
     Run inference on a video file.
@@ -1066,20 +1098,24 @@ def inference(
             - dict: config parameters directly
         backbone: ResNet backbone ('resnet18', 'resnet50', 'resnet101').
             If model checkpoint contains backbone info, it will be used instead.
+        crops: If True, save cropped frames for each classified track
+        save_video: If True, save annotated and debug videos. Defaults to True.
+        img_size: Image size for classification (should match training). Default: 60.
 
     Returns:
         dict: Processing results with output file paths
 
     Generated files in output_dir:
-        - {video_name}_annotated.mp4: Video with detection boxes and paths
-        - {video_name}_debug.mp4: Side-by-side with GMM motion mask
+        - {video_name}_annotated.mp4: Video with detection boxes and paths (if save_video=True)
+        - {video_name}_debug.mp4: Side-by-side with GMM motion mask (if save_video=True)
         - {video_name}_results.csv: Aggregated track results
         - {video_name}_detections.csv: Frame-by-frame detections
+        - {video_name}_crops/ (if crops=True): Directory with cropped frames per track
     """
     if not os.path.exists(video_path):
         print(f"Error: Video not found: {video_path}")
         return {"error": f"Video not found: {video_path}", "success": False}
-
+
     # Build parameters from config
     if config is None:
         params = get_default_config()
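Putting the three new keyword arguments together, a hedged sketch of a crops-enabled call; as before, the video and model keyword names are inferred from the diff.

```python
import bplusplus

# Sketch only: enables crop export alongside the default video rendering.
result = bplusplus.inference(
    video_path="pollinators.mp4",
    hierarchical_model_path="best_model.pth",
    output_dir="output",
    crops=True,   # writes output/pollinators_crops/<track id>/frame_*.jpg
    save_video=True,
    img_size=60,
)
if result.get("success"):
    print(result)  # the returned dict reports the generated output paths
```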
@@ -1096,16 +1132,23 @@ def inference(
         raise ValueError("config must be None, a file path (str), or a dict")
 
     # Setup output directory and file paths
-    os.makedirs(output_dir, exist_ok=True)
+    os.makedirs(output_dir, exist_ok=True)
     video_name = os.path.splitext(os.path.basename(video_path))[0]
 
     output_paths = {
-        "annotated_video": os.path.join(output_dir, f"{video_name}_annotated.mp4"),
-        "debug_video": os.path.join(output_dir, f"{video_name}_debug.mp4"),
         "results_csv": os.path.join(output_dir, f"{video_name}_results.csv"),
         "detections_csv": os.path.join(output_dir, f"{video_name}_detections.csv"),
     }
 
+    if save_video:
+        output_paths["annotated_video"] = os.path.join(output_dir, f"{video_name}_annotated.mp4")
+        output_paths["debug_video"] = os.path.join(output_dir, f"{video_name}_debug.mp4")
+
+    # Setup crops directory if requested
+    crops_dir = os.path.join(output_dir, f"{video_name}_crops") if crops else None
+    if crops_dir:
+        output_paths["crops_dir"] = crops_dir
+
     print("\n" + "="*60)
     print("BPLUSPLUS INFERENCE")
     print("="*60)
@@ -1126,14 +1169,16 @@ def inference(
         hierarchical_model_path=hierarchical_model_path,
         params=params,
         backbone=backbone,
+        img_size=img_size,
     )
-
+
     try:
         results = process_video(
             video_path=video_path,
             processor=processor,
             output_paths=output_paths,
-            fps=fps
+            fps=fps,
+            crops_dir=crops_dir
         )
 
         return {
@@ -1174,6 +1219,7 @@ Output files generated in output directory:
   - {video_name}_debug.mp4: Side-by-side view with GMM motion mask
   - {video_name}_results.csv: Aggregated track results
   - {video_name}_detections.csv: Frame-by-frame detections
+  - {video_name}_crops/ (with --crops): Cropped frames for each track
 """
 )
 
@@ -1192,6 +1238,12 @@ Output files generated in output directory:
     parser.add_argument('--backbone', '-b', default='resnet50',
                         choices=['resnet18', 'resnet50', 'resnet101'],
                         help='ResNet backbone (default: resnet50, overridden by checkpoint if saved)')
+    parser.add_argument('--crops', action='store_true',
+                        help='Save cropped frames for each classified track')
+    parser.add_argument('--no-video', action='store_true',
+                        help='Skip saving annotated and debug videos')
+    parser.add_argument('--img-size', type=int, default=60,
+                        help='Image size for classification (should match training, default: 60)')
 
     # Detection parameters (override config)
     defaults = DEFAULT_DETECTION_CONFIG
@@ -1267,6 +1319,9 @@ Output files generated in output directory:
         fps=args.fps,
         config=config,
         backbone=args.backbone,
+        crops=args.crops,
+        save_video=not args.no_video,
+        img_size=args.img_size,
     )
 
     if result.get("success"):
@@ -1279,4 +1334,4 @@
 
 
 if __name__ == "__main__":
-    main()
+    main()
@@ -7,7 +7,7 @@ from typing import Optional
 
 import requests
 import torch
-from PIL import Image
+from PIL import Image, ImageFilter
 from torch import serialization
 from torch.nn import Module, ModuleDict, ModuleList
 from torch.nn.modules.activation import LeakyReLU, ReLU, SiLU
@@ -27,7 +27,7 @@ from ultralytics.nn.modules.conv import Conv
 from ultralytics.nn.tasks import DetectionModel
 
 
-def prepare(input_directory: str, output_directory: str, img_size: int = 40, conf: float = 0.35, valid: float = 0.1):
+def prepare(input_directory: str, output_directory: str, img_size: int = 40, conf: float = 0.35, valid: float = 0.1, blur: Optional[float] = None):
     """
     Prepares a YOLO classification dataset by performing the following steps:
     1. Copies images from input directory to temporary directory and creates class mapping.
@@ -44,10 +44,16 @@ def prepare(input_directory: str, output_directory: str, img_size: int = 40, con
         conf (float, optional): YOLO detection confidence threshold. Defaults to 0.35.
         valid (float, optional): Fraction of data for validation (0.0 to 1.0).
             0 = no validation split, 0.1 = 10% validation. Defaults to 0.1.
+        blur (float, optional): Gaussian blur as fraction of image size (0.0 to 1.0).
+            Applied before resizing. 0.01 = 1% of smallest dimension.
+            None or 0 means no blur. Defaults to None.
     """
     # Validate the valid parameter
     if not 0 <= valid <= 1:
         raise ValueError(f"valid must be between 0 and 1, got {valid}")
+    # Validate the blur parameter
+    if blur is not None and not 0 <= blur <= 1:
+        raise ValueError(f"blur must be between 0 and 1, got {blur}")
     input_directory = Path(input_directory)
     output_directory = Path(output_directory)
 
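A quick sketch of the new range check in action; per the validation above, an out-of-range `blur` fails fast before any images are touched (directory names are placeholders).

```python
import bplusplus

try:
    bplusplus.prepare(input_directory="raw", output_directory="prepared", blur=1.5)
except ValueError as err:
    print(err)  # blur must be between 0 and 1, got 1.5
```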
@@ -62,6 +68,10 @@ def prepare(input_directory: str, output_directory: str, img_size: int = 40, con
         print(f"Validation split: {valid*100:.0f}% validation, {(1-valid)*100:.0f}% training")
     else:
         print("Validation split: disabled (all images to training)")
+    if blur and blur > 0:
+        print(f"Gaussian blur: {blur*100:.1f}% of image size")
+    else:
+        print("Gaussian blur: disabled")
     print()
 
     with tempfile.TemporaryDirectory() as temp_dir:
@@ -106,7 +116,7 @@ def prepare(input_directory: str, output_directory: str, img_size: int = 40, con
         print("-" * 50)
         _finalize_dataset(
             class_mapping, temp_dir_path, output_directory,
-            class_idxs, original_image_count, img_size, valid
+            class_idxs, original_image_count, img_size, valid, blur
         )
         print("✓ Step 5 completed: Classification dataset ready!")
         print()
@@ -137,7 +147,7 @@ def _setup_directories_and_copy_images(input_directory: Path, temp_dir_path: Pat
         images_names = []
         if folder_directory.is_dir():
             folder_name = folder_directory.name
-            image_files = list(folder_directory.glob("*.jpg"))
+            image_files = list(folder_directory.glob("*.jpg")) + list(folder_directory.glob("*.png"))
             print(f"  Copying {len(image_files)} images from class '{folder_name}'...")
 
             for image_file in image_files:
@@ -149,7 +159,7 @@ def _setup_directories_and_copy_images(input_directory: Path, temp_dir_path: Pat
             class_mapping[folder_name] = images_names
             print(f"  ✓ {len(images_names)} images copied for class '{folder_name}'")
 
-    original_image_count = len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.jpeg")))
+    original_image_count = len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.jpeg"))) + len(list(images_path.glob("*.png")))
     print(f"  Total images in temporary directory: {original_image_count}")
 
     return class_mapping, original_image_count
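The same jpg/jpeg/png globbing pattern now recurs throughout the module; a small helper like the following (a sketch, not part of the package) would express it once:

```python
from pathlib import Path

def image_files(directory: Path, patterns=("*.jpg", "*.jpeg", "*.png")) -> list[Path]:
    """Collect image files matching any of the supported extensions."""
    files: list[Path] = []
    for pattern in patterns:
        files.extend(directory.glob(pattern))
    return files
```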
@@ -165,9 +175,9 @@ def _prepare_model_and_clean_images(temp_dir_path: Path):
 
     # Clean corrupted images
     print("  Checking for corrupted images...")
-    images_before = len(list(images_path.glob("*.jpg")))
+    images_before = len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.png")))
     __delete_corrupted_images(images_path)
-    images_after = len(list(images_path.glob("*.jpg")))
+    images_after = len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.png")))
     deleted_count = images_before - images_after
     print(f"  ✓ Cleaned {deleted_count} corrupted images ({images_after} images remain)")
 
@@ -208,7 +218,7 @@ def _run_yolo_inference(temp_dir_path: Path, weights_path: Path, conf: float):
         temp_dir_path (Path): Path to the working temp directory.
         weights_path (Path): Path to YOLO weights.
         conf (float): YOLO detection confidence threshold.
-
+
     Returns:
         Path: labels_path where the generated labels are stored
     """
@@ -221,7 +231,7 @@ def _run_yolo_inference(temp_dir_path: Path, weights_path: Path, conf: float):
     print("  ✓ YOLO model loaded successfully")
 
     # Get list of all image files
-    image_files = list(images_path.glob('*.jpg'))
+    image_files = list(images_path.glob('*.jpg')) + list(images_path.glob('*.png'))
     print(f"  Found {len(image_files)} images to process with YOLO")
 
     # Ensure predict directory exists
@@ -288,13 +298,13 @@ def _cleanup_and_process_labels(temp_dir_path: Path, labels_path: Path, class_ma
     images_path = temp_dir_path / "images"
 
     print("  Cleaning up orphaned images and labels...")
-    images_before = len(list(images_path.glob("*.jpg")))
+    images_before = len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.png")))
     labels_before = len(list(labels_path.glob("*.txt")))
 
     __delete_orphaned_images_and_inferences(images_path, labels_path)
     __delete_invalid_txt_files(images_path, labels_path)
 
-    images_after = len(list(images_path.glob("*.jpg")))
+    images_after = len(list(images_path.glob("*.jpg"))) + len(list(images_path.glob("*.png")))
     labels_after = len(list(labels_path.glob("*.txt")))
 
     deleted_images = images_before - images_after
@@ -313,15 +323,16 @@ def _cleanup_and_process_labels(temp_dir_path: Path, labels_path: Path, class_ma
 
 def _finalize_dataset(class_mapping: dict, temp_dir_path: Path, output_directory: Path,
                       class_idxs: dict, original_image_count: int, img_size: int,
-                      valid_fraction: float = 0.1):
+                      valid_fraction: float = 0.1, blur: Optional[float] = None):
     """
     Finalizes the dataset by creating cropped classification images and splitting into train/valid sets.
 
     Args:
         valid_fraction: Fraction of data for validation (0.0 to 1.0). 0 = no validation split.
+        blur: Gaussian blur as fraction of image size (0-1). None or 0 means no blur.
     """
     # Split data into train/valid with cropped classification images
-    __classification_split(class_mapping, temp_dir_path, output_directory, img_size, valid_fraction)
+    __classification_split(class_mapping, temp_dir_path, output_directory, img_size, valid_fraction, blur)
 
     # Generate final report
     print("  Generating final statistics...")
@@ -345,11 +356,12 @@ def __delete_corrupted_images(images_path: Path):
     it cannot be opened), the function deletes the corrupted image file.
     """
 
-    for image_file in images_path.glob("*.jpg"):
-        try:
-            Image.open(image_file)
-        except IOError:
-            image_file.unlink()
+    for pattern in ["*.jpg", "*.png"]:
+        for image_file in images_path.glob(pattern):
+            try:
+                Image.open(image_file)
+            except IOError:
+                image_file.unlink()
 
 def __download_file_from_github_release(url, dest_path):
 
@@ -399,13 +411,14 @@ def __delete_orphaned_images_and_inferences(images_path: Path, labels_path: Path
     for txt_file in labels_path.glob("*.txt"):
         image_file_jpg = images_path / (txt_file.stem + ".jpg")
         image_file_jpeg = images_path / (txt_file.stem + ".jpeg")
+        image_file_png = images_path / (txt_file.stem + ".png")
 
-        if not (image_file_jpg.exists() or image_file_jpeg.exists()):
+        if not (image_file_jpg.exists() or image_file_jpeg.exists() or image_file_png.exists()):
             # print(f"Deleting {txt_file.name} - No corresponding image file")
             txt_file.unlink()
 
     label_stems = {txt_file.stem for txt_file in labels_path.glob("*.txt")}
-    image_files = list(images_path.glob("*.jpg")) + list(images_path.glob("*.jpeg"))
+    image_files = list(images_path.glob("*.jpg")) + list(images_path.glob("*.jpeg")) + list(images_path.glob("*.png"))
 
     for image_file in image_files:
         if image_file.stem not in label_stems:
@@ -439,6 +452,7 @@ def __delete_invalid_txt_files(images_path: Path, labels_path: Path):
 
         image_file_jpg = images_path / (txt_file.stem + ".jpg")
         image_file_jpeg = images_path / (txt_file.stem + ".jpeg")
+        image_file_png = images_path / (txt_file.stem + ".png")
 
         if image_file_jpg.exists():
             image_file_jpg.unlink()
@@ -446,11 +460,14 @@ def __delete_invalid_txt_files(images_path: Path, labels_path: Path):
         elif image_file_jpeg.exists():
             image_file_jpeg.unlink()
             # print(f"Deleted corresponding image file: {image_file_jpeg.name}")
+        elif image_file_png.exists():
+            image_file_png.unlink()
+            # print(f"Deleted corresponding image file: {image_file_png.name}")
 
 
 
 
-def __classification_split(class_mapping: dict, temp_dir_path: Path, output_directory: Path, img_size: int, valid_fraction: float = 0.1):
+def __classification_split(class_mapping: dict, temp_dir_path: Path, output_directory: Path, img_size: int, valid_fraction: float = 0.1, blur: Optional[float] = None):
     """
     Splits the data into train and validation sets for classification tasks,
     cropping images according to their YOLO labels but preserving original class structure.
@@ -461,6 +478,7 @@ def __classification_split(class_mapping: dict, temp_dir_path: Path, output_dire
         output_directory (Path): The path to the output directory where train and valid splits will be created.
         img_size (int): The target size for the smallest dimension of cropped images.
         valid_fraction (float): Fraction of data for validation (0.0 to 1.0). 0 = no validation split.
+        blur (float, optional): Gaussian blur as fraction of image size (0-1). None or 0 means no blur.
     """
     images_dir = temp_dir_path / "images"
     labels_dir = temp_dir_path / "predict" / "labels"
@@ -544,6 +562,12 @@ def __classification_split(class_mapping: dict, temp_dir_path: Path, output_dire
 
             img = img.crop((x_min, y_min, x_max, y_max))
 
+            # Apply Gaussian blur if specified (blur is fraction of smallest dimension)
+            if blur and blur > 0:
+                img_width, img_height = img.size
+                blur_radius = blur * min(img_width, img_height)
+                img = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
+
             img_width, img_height = img.size
             if img_width < img_height:
                 # Width is smaller, set to img_size
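The relative-blur arithmetic above, as a standalone sketch: the radius scales with the crop's smallest dimension, so `blur=0.01` blurs a 640×480 crop with radius 4.8 but a 64×48 crop with only radius 0.48.

```python
from PIL import Image, ImageFilter

def blur_relative(img: Image.Image, blur: float) -> Image.Image:
    """Apply Gaussian blur with a radius proportional to the smallest side."""
    if not blur:  # None or 0 disables blurring, matching prepare()
        return img
    radius = blur * min(img.size)  # img.size is (width, height)
    return img.filter(ImageFilter.GaussianBlur(radius=radius))
```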
@@ -677,6 +701,6 @@ def count_images_across_splits(output_directory: Path) -> int:
         # Count all images in all class subdirectories
         for class_dir in split_dir.iterdir():
             if class_dir.is_dir():
-                total_images += len(list(class_dir.glob("*.jpg"))) + len(list(class_dir.glob("*.jpeg")))
+                total_images += len(list(class_dir.glob("*.jpg"))) + len(list(class_dir.glob("*.jpeg"))) + len(list(class_dir.glob("*.png")))
 
     return total_images
@@ -74,7 +74,7 @@ def train(batch_size=4, epochs=30, patience=3, img_size=640, data_dir='input', o
         if any(f.lower().endswith(('.jpg', '.jpeg', '.png')) for f in files):
             return True
         return False
-
+
     train_dataset = InsectDataset(
         root_dir=train_dir,
         transform=train_transforms or get_transforms(is_training=True, img_size=img_size),
@@ -107,7 +107,7 @@ def train(batch_size=4, epochs=30, patience=3, img_size=640, data_dir='input', o
         shuffle=False,
         num_workers=num_workers
     )
-
+
     train_loader = DataLoader(
         train_dataset,
         batch_size=batch_size,
@@ -819,7 +819,7 @@ def train_model(model, train_loader, val_loader, criterion, optimizer, level_to_
             print("Validation skipped (no valid data found).")
             print('-' * 60)
             continue
-
+
         model.eval()
         val_running_loss = 0.0
         val_correct_predictions = [0] * model.num_levels
@@ -899,7 +899,7 @@ def train_model(model, train_loader, val_loader, criterion, optimizer, level_to_
             'backbone': backbone
         }, best_model_path)
         logger.info(f"Saved model (validation skipped) at {best_model_path}")
-
+
     logger.info("Training completed successfully")
     return model
 