sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +168 -39
  6. sleap_nn/evaluation.py +8 -0
  7. sleap_nn/export/__init__.py +21 -0
  8. sleap_nn/export/cli.py +1778 -0
  9. sleap_nn/export/exporters/__init__.py +51 -0
  10. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  11. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  12. sleap_nn/export/metadata.py +225 -0
  13. sleap_nn/export/predictors/__init__.py +63 -0
  14. sleap_nn/export/predictors/base.py +22 -0
  15. sleap_nn/export/predictors/onnx.py +154 -0
  16. sleap_nn/export/predictors/tensorrt.py +312 -0
  17. sleap_nn/export/utils.py +307 -0
  18. sleap_nn/export/wrappers/__init__.py +25 -0
  19. sleap_nn/export/wrappers/base.py +96 -0
  20. sleap_nn/export/wrappers/bottomup.py +243 -0
  21. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  22. sleap_nn/export/wrappers/centered_instance.py +56 -0
  23. sleap_nn/export/wrappers/centroid.py +58 -0
  24. sleap_nn/export/wrappers/single_instance.py +83 -0
  25. sleap_nn/export/wrappers/topdown.py +180 -0
  26. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  27. sleap_nn/inference/peak_finding.py +47 -17
  28. sleap_nn/inference/postprocessing.py +284 -0
  29. sleap_nn/inference/predictors.py +213 -106
  30. sleap_nn/predict.py +35 -7
  31. sleap_nn/train.py +64 -0
  32. sleap_nn/training/callbacks.py +69 -22
  33. sleap_nn/training/lightning_modules.py +332 -30
  34. sleap_nn/training/model_trainer.py +67 -67
  35. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/METADATA +13 -1
  36. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/RECORD +40 -19
  37. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/WHEEL +0 -0
  38. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/entry_points.txt +0 -0
  39. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
  40. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/top_level.txt +0 -0
sleap_nn/inference/predictors.py CHANGED
@@ -56,6 +56,8 @@ from rich.progress import (
     MofNCompleteColumn,
 )
 from time import time
+import json
+import sys


 def _filter_user_labeled_frames(
@@ -133,6 +135,8 @@ class Predictor(ABC):
             `backbone_config`. This determines the downsampling factor applied by the backbone,
             and is used to ensure that input images are padded or resized to be compatible
             with the model's architecture. Default: 16.
+        gui: If True, outputs JSON progress lines for GUI integration instead of
+            Rich progress bars. Default: False.
     """

     preprocess: bool = True
@@ -152,6 +156,7 @@ class Predictor(ABC):
     ] = None
     instances_key: bool = False
     max_stride: int = 16
+    gui: bool = False

     @classmethod
     def from_model_paths(
@@ -381,6 +386,102 @@ class Predictor(ABC):
                 v[n] = v[n].cpu().numpy()
         return output

+    def _process_batch(self) -> tuple:
+        """Process a single batch of frames from the pipeline.
+
+        Returns:
+            Tuple of (imgs, fidxs, vidxs, org_szs, instances, eff_scales, done)
+            where done is True if the pipeline has finished.
+        """
+        imgs = []
+        fidxs = []
+        vidxs = []
+        org_szs = []
+        instances = []
+        eff_scales = []
+        done = False
+
+        for _ in range(self.batch_size):
+            frame = self.pipeline.frame_buffer.get()
+            if frame["image"] is None:
+                done = True
+                break
+            frame["image"], eff_scale = apply_sizematcher(
+                frame["image"],
+                self.preprocess_config["max_height"],
+                self.preprocess_config["max_width"],
+            )
+            if self.instances_key:
+                frame["instances"] = frame["instances"] * eff_scale
+            if self.preprocess_config["ensure_rgb"] and frame["image"].shape[-3] != 3:
+                frame["image"] = frame["image"].repeat(1, 3, 1, 1)
+            elif (
+                self.preprocess_config["ensure_grayscale"]
+                and frame["image"].shape[-3] != 1
+            ):
+                frame["image"] = F.rgb_to_grayscale(
+                    frame["image"], num_output_channels=1
+                )
+
+            eff_scales.append(torch.tensor(eff_scale))
+            imgs.append(frame["image"].unsqueeze(dim=0))
+            fidxs.append(frame["frame_idx"])
+            vidxs.append(frame["video_idx"])
+            org_szs.append(frame["orig_size"].unsqueeze(dim=0))
+            if self.instances_key:
+                instances.append(frame["instances"].unsqueeze(dim=0))
+
+        return imgs, fidxs, vidxs, org_szs, instances, eff_scales, done
+
+    def _run_inference_on_batch(
+        self, imgs, fidxs, vidxs, org_szs, instances, eff_scales
+    ) -> Iterator[Dict[str, np.ndarray]]:
+        """Run inference on a prepared batch of frames.
+
+        Args:
+            imgs: List of image tensors.
+            fidxs: List of frame indices.
+            vidxs: List of video indices.
+            org_szs: List of original sizes.
+            instances: List of instance tensors.
+            eff_scales: List of effective scales.
+
+        Yields:
+            Dictionaries containing inference results for each frame.
+        """
+        # TODO: all preprocessing should be moved into InferenceModels to be exportable.
+        imgs = torch.concatenate(imgs, dim=0)
+        fidxs = torch.tensor(fidxs, dtype=torch.int32)
+        vidxs = torch.tensor(vidxs, dtype=torch.int32)
+        org_szs = torch.concatenate(org_szs, dim=0)
+        eff_scales = torch.tensor(eff_scales, dtype=torch.float32)
+        if self.instances_key:
+            instances = torch.concatenate(instances, dim=0)
+        ex = {
+            "image": imgs,
+            "frame_idx": fidxs,
+            "video_idx": vidxs,
+            "orig_size": org_szs,
+            "eff_scale": eff_scales,
+        }
+        if self.instances_key:
+            ex["instances"] = instances
+        if self.preprocess:
+            scale = self.preprocess_config["scale"]
+            if scale != 1.0:
+                if self.instances_key:
+                    ex["image"], ex["instances"] = apply_resizer(
+                        ex["image"], ex["instances"]
+                    )
+                else:
+                    ex["image"] = resize_image(ex["image"], scale)
+            ex["image"] = apply_pad_to_stride(ex["image"], self.max_stride)
+        outputs_list = self.inference_model(ex)
+        if outputs_list is not None:
+            for output in outputs_list:
+                output = self._convert_tensors_to_numpy(output)
+                yield output
+
     def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]:
         """Create a generator that yields batches of inference results.

@@ -400,114 +501,14 @@
         # Loop over data batches.
         self.pipeline.start()
         total_frames = self.pipeline.total_len()
-        done = False

         try:
-            with Progress(
-                "{task.description}",
-                BarColumn(),
-                "[progress.percentage]{task.percentage:>3.0f}%",
-                MofNCompleteColumn(),
-                "ETA:",
-                TimeRemainingColumn(),
-                "Elapsed:",
-                TimeElapsedColumn(),
-                RateColumn(),
-                auto_refresh=False,
-                refresh_per_second=4,  # Change to self.report_rate if needed
-                speed_estimate_period=5,
-            ) as progress:
-                task = progress.add_task("Predicting...", total=total_frames)
-                last_report = time()
-
-                done = False
-                while not done:
-                    imgs = []
-                    fidxs = []
-                    vidxs = []
-                    org_szs = []
-                    instances = []
-                    eff_scales = []
-                    for _ in range(self.batch_size):
-                        frame = self.pipeline.frame_buffer.get()
-                        if frame["image"] is None:
-                            done = True
-                            break
-                        frame["image"], eff_scale = apply_sizematcher(
-                            frame["image"],
-                            self.preprocess_config["max_height"],
-                            self.preprocess_config["max_width"],
-                        )
-                        if self.instances_key:
-                            frame["instances"] = frame["instances"] * eff_scale
-                        if (
-                            self.preprocess_config["ensure_rgb"]
-                            and frame["image"].shape[-3] != 3
-                        ):
-                            frame["image"] = frame["image"].repeat(1, 3, 1, 1)
-                        elif (
-                            self.preprocess_config["ensure_grayscale"]
-                            and frame["image"].shape[-3] != 1
-                        ):
-                            frame["image"] = F.rgb_to_grayscale(
-                                frame["image"], num_output_channels=1
-                            )
-
-                        eff_scales.append(torch.tensor(eff_scale))
-                        imgs.append(frame["image"].unsqueeze(dim=0))
-                        fidxs.append(frame["frame_idx"])
-                        vidxs.append(frame["video_idx"])
-                        org_szs.append(frame["orig_size"].unsqueeze(dim=0))
-                        if self.instances_key:
-                            instances.append(frame["instances"].unsqueeze(dim=0))
-                    if imgs:
-                        # TODO: all preprocessing should be moved into InferenceModels to be exportable.
-                        imgs = torch.concatenate(imgs, dim=0)
-                        fidxs = torch.tensor(fidxs, dtype=torch.int32)
-                        vidxs = torch.tensor(vidxs, dtype=torch.int32)
-                        org_szs = torch.concatenate(org_szs, dim=0)
-                        eff_scales = torch.tensor(eff_scales, dtype=torch.float32)
-                        if self.instances_key:
-                            instances = torch.concatenate(instances, dim=0)
-                        ex = {
-                            "image": imgs,
-                            "frame_idx": fidxs,
-                            "video_idx": vidxs,
-                            "orig_size": org_szs,
-                            "eff_scale": eff_scales,
-                        }
-                        if self.instances_key:
-                            ex["instances"] = instances
-                        if self.preprocess:
-                            scale = self.preprocess_config["scale"]
-                            if scale != 1.0:
-                                if self.instances_key:
-                                    ex["image"], ex["instances"] = apply_resizer(
-                                        ex["image"], ex["instances"]
-                                    )
-                                else:
-                                    ex["image"] = resize_image(ex["image"], scale)
-                            ex["image"] = apply_pad_to_stride(
-                                ex["image"], self.max_stride
-                            )
-                        outputs_list = self.inference_model(ex)
-                        if outputs_list is not None:
-                            for output in outputs_list:
-                                output = self._convert_tensors_to_numpy(output)
-                                yield output
-
-                        # Advance progress
-                        num_frames = (
-                            len(ex["frame_idx"])
-                            if "frame_idx" in ex
-                            else self.batch_size
-                        )
-                        progress.update(task, advance=num_frames)
-
-                        # Manually refresh progress bar
-                        if time() - last_report > 0.25:
-                            progress.refresh()
-                            last_report = time()
+            if self.gui:
+                # GUI mode: emit JSON progress lines
+                yield from self._predict_generator_gui(total_frames)
+            else:
+                # Normal mode: use Rich progress bar
+                yield from self._predict_generator_rich(total_frames)

         except KeyboardInterrupt:
             logger.info("Inference interrupted by user")
@@ -520,6 +521,112 @@

         self.pipeline.join()

+    def _predict_generator_gui(
+        self, total_frames: int
+    ) -> Iterator[Dict[str, np.ndarray]]:
+        """Generator for GUI mode with JSON progress output.
+
+        Args:
+            total_frames: Total number of frames to process.
+
+        Yields:
+            Dictionaries containing inference results for each frame.
+        """
+        start_time = time()
+        frames_processed = 0
+        last_report = time()
+        done = False
+
+        while not done:
+            imgs, fidxs, vidxs, org_szs, instances, eff_scales, done = (
+                self._process_batch()
+            )
+
+            if imgs:
+                yield from self._run_inference_on_batch(
+                    imgs, fidxs, vidxs, org_szs, instances, eff_scales
+                )
+
+                # Update progress
+                num_frames = len(fidxs)
+                frames_processed += num_frames
+
+                # Emit JSON progress (throttled to ~4Hz)
+                if time() - last_report > 0.25:
+                    elapsed = time() - start_time
+                    rate = frames_processed / elapsed if elapsed > 0 else 0
+                    remaining = total_frames - frames_processed
+                    eta = remaining / rate if rate > 0 else 0
+
+                    progress_data = {
+                        "n_processed": frames_processed,
+                        "n_total": total_frames,
+                        "rate": round(rate, 1),
+                        "eta": round(eta, 1),
+                    }
+                    print(json.dumps(progress_data), flush=True)
+                    last_report = time()
+
+        # Final progress emit to ensure 100% is shown
+        elapsed = time() - start_time
+        progress_data = {
+            "n_processed": total_frames,
+            "n_total": total_frames,
+            "rate": round(frames_processed / elapsed, 1) if elapsed > 0 else 0,
+            "eta": 0,
+        }
+        print(json.dumps(progress_data), flush=True)
+
+    def _predict_generator_rich(
+        self, total_frames: int
+    ) -> Iterator[Dict[str, np.ndarray]]:
+        """Generator for normal mode with Rich progress bar.
+
+        Args:
+            total_frames: Total number of frames to process.
+
+        Yields:
+            Dictionaries containing inference results for each frame.
+        """
+        with Progress(
+            "{task.description}",
+            BarColumn(),
+            "[progress.percentage]{task.percentage:>3.0f}%",
+            MofNCompleteColumn(),
+            "ETA:",
+            TimeRemainingColumn(),
+            "Elapsed:",
+            TimeElapsedColumn(),
+            RateColumn(),
+            auto_refresh=False,
+            refresh_per_second=4,
+            speed_estimate_period=5,
+        ) as progress:
+            task = progress.add_task("Predicting...", total=total_frames)
+            last_report = time()
+            done = False
+
+            while not done:
+                imgs, fidxs, vidxs, org_szs, instances, eff_scales, done = (
+                    self._process_batch()
+                )
+
+                if imgs:
+                    yield from self._run_inference_on_batch(
+                        imgs, fidxs, vidxs, org_szs, instances, eff_scales
+                    )
+
+                    # Advance progress
+                    num_frames = len(fidxs)
+                    progress.update(task, advance=num_frames)
+
+                    # Manually refresh progress bar
+                    if time() - last_report > 0.25:
+                        progress.refresh()
+                        last_report = time()
+
+        self.pipeline.join()
+
     def predict(
         self,
         make_labels: bool = True,
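
Note on the GUI progress protocol added above: each update is a single JSON object printed to stdout with the keys n_processed, n_total, rate (frames per second), and eta (seconds), followed by one final line reporting completion. A minimal consumer sketch, assuming inference runs as a subprocess; the exact CLI invocation below is hypothetical and not taken from this diff:

    import json
    import subprocess

    # Hypothetical command line; substitute the actual sleap-nn entry point and flags.
    proc = subprocess.Popen(
        ["sleap-nn", "track", "video.mp4", "--gui"],
        stdout=subprocess.PIPE,
        text=True,
    )
    for line in proc.stdout:
        try:
            progress = json.loads(line)
        except json.JSONDecodeError:
            continue  # Not a progress line (e.g., ordinary log output).
        pct = 100 * progress["n_processed"] / max(progress["n_total"], 1)
        print(f'{pct:.0f}% at {progress["rate"]} FPS, ETA {progress["eta"]} s')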
sleap_nn/predict.py CHANGED
@@ -74,6 +74,9 @@ def run_inference(
     frames: Optional[list] = None,
     crop_size: Optional[int] = None,
     peak_threshold: Union[float, List[float]] = 0.2,
+    filter_overlapping: bool = False,
+    filter_overlapping_method: str = "iou",
+    filter_overlapping_threshold: float = 0.8,
     integral_refinement: Optional[str] = "integral",
     integral_patch_size: int = 5,
     return_confmaps: bool = False,
@@ -110,6 +113,7 @@
     tracking_pre_cull_iou_threshold: float = 0,
     tracking_clean_instance_count: int = 0,
     tracking_clean_iou_threshold: float = 0,
+    gui: bool = False,
 ):
     """Entry point to run inference on trained SLEAP-NN models.

@@ -160,6 +164,15 @@
             centroid and centered-instance model, where the first element corresponds
             to centroid model peak finding threshold and the second element is for
             centered-instance model peak finding.
+        filter_overlapping: (bool) If True, removes overlapping instances after
+            inference using greedy NMS. Applied independently of tracking.
+            Default: False.
+        filter_overlapping_method: (str) Similarity metric for filtering overlapping
+            instances. One of "iou" (bounding box) or "oks" (keypoint similarity).
+            Default: "iou".
+        filter_overlapping_threshold: (float) Similarity threshold for filtering.
+            Instances with similarity > threshold are removed (keeping higher-scoring).
+            Typical values: 0.3 (aggressive) to 0.8 (permissive). Default: 0.8.
         integral_refinement: (str) If `None`, returns the grid-aligned peaks with no refinement.
             If `"integral"`, peaks will be refined with integral regression.
             Default: `"integral"`.
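
The three new options flow straight through run_inference. A minimal usage sketch; the model and video paths are placeholders, and the model_paths/data_path parameter names are assumed from the surrounding API rather than shown in this hunk:

    from sleap_nn.predict import run_inference

    labels = run_inference(
        model_paths=["models/my_trained_model"],  # placeholder path
        data_path="video.mp4",                    # placeholder path
        filter_overlapping=True,
        filter_overlapping_method="oks",   # keypoint similarity instead of bounding-box IOU
        filter_overlapping_threshold=0.5,  # stricter than the permissive 0.8 default
    )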
@@ -250,6 +263,8 @@
        tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
        tracking_clean_instance_count: Target number of instances to clean *after* tracking. (default: 0)
        tracking_clean_iou_threshold: IOU to use when culling instances *after* tracking. (default: 0)
+        gui: (bool) If True, outputs JSON progress lines for GUI integration instead
+            of Rich progress bars. Default: False.

    Returns:
        Returns `sio.Labels` object if `make_labels` is True. Else this function returns
@@ -433,13 +448,6 @@
         else "mps" if torch.backends.mps.is_available() else "cpu"
     )

-    if integral_refinement is not None and device == "mps":  # TODO
-        # kornia/geometry/transform/imgwarp.py:382: in get_perspective_transform. NotImplementedError: The operator 'aten::_linalg_solve_ex.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
-        logger.info(
-            "Integral refinement is not supported with MPS accelerator. Setting integral refinement to None."
-        )
-        integral_refinement = None
-
     logger.info(f"Using device: {device}")

     # initializes the inference model
@@ -458,6 +466,9 @@
         anchor_part=anchor_part,
     )

+    # Set GUI mode for progress output
+    predictor.gui = gui
+
    if (
        tracking
        and not isinstance(predictor, BottomUpMultiClassPredictor)
@@ -553,6 +564,20 @@
             make_labels=make_labels,
         )

+        # Filter overlapping instances (independent of tracking)
+        if filter_overlapping and make_labels:
+            from sleap_nn.inference.postprocessing import filter_overlapping_instances
+
+            output = filter_overlapping_instances(
+                output,
+                threshold=filter_overlapping_threshold,
+                method=filter_overlapping_method,
+            )
+            logger.info(
+                f"Filtered overlapping instances with {filter_overlapping_method.upper()} "
+                f"threshold: {filter_overlapping_threshold}"
+            )
+
        if tracking:
            lfs = [x for x in output]
            if tracking_clean_instance_count > 0:
@@ -607,6 +632,9 @@
    # Build inference parameters for provenance
    inference_params = {
        "peak_threshold": peak_threshold,
+        "filter_overlapping": filter_overlapping,
+        "filter_overlapping_method": filter_overlapping_method,
+        "filter_overlapping_threshold": filter_overlapping_threshold,
        "integral_refinement": integral_refinement,
        "integral_patch_size": integral_patch_size,
        "batch_size": batch_size,
sleap_nn/train.py CHANGED
@@ -118,6 +118,70 @@ def run_training(
             logger.info(f"p90 dist: {metrics['distance_metrics']['p90']}")
             logger.info(f"p50 dist: {metrics['distance_metrics']['p50']}")

+            # Log test metrics to wandb summary
+            if (
+                d_name.startswith("test")
+                and trainer.config.trainer_config.use_wandb
+            ):
+                import wandb
+
+                if wandb.run is not None:
+                    summary_metrics = {
+                        f"eval/{d_name}/mOKS": metrics["mOKS"]["mOKS"],
+                        f"eval/{d_name}/oks_voc_mAP": metrics["voc_metrics"]["oks_voc.mAP"],
+                        f"eval/{d_name}/oks_voc_mAR": metrics["voc_metrics"]["oks_voc.mAR"],
+                        f"eval/{d_name}/mPCK": metrics["pck_metrics"]["mPCK"],
+                        f"eval/{d_name}/PCK_5": metrics["pck_metrics"]["PCK@5"],
+                        f"eval/{d_name}/PCK_10": metrics["pck_metrics"]["PCK@10"],
+                        f"eval/{d_name}/distance_avg": metrics["distance_metrics"]["avg"],
+                        f"eval/{d_name}/distance_p50": metrics["distance_metrics"]["p50"],
+                        f"eval/{d_name}/distance_p95": metrics["distance_metrics"]["p95"],
+                        f"eval/{d_name}/distance_p99": metrics["distance_metrics"]["p99"],
+                        f"eval/{d_name}/visibility_precision": metrics["visibility_metrics"]["precision"],
+                        f"eval/{d_name}/visibility_recall": metrics["visibility_metrics"]["recall"],
+                    }
+                    for key, value in summary_metrics.items():
+                        wandb.run.summary[key] = value
+
+    # Finish wandb run and cleanup after all evaluation is complete
+    if trainer.config.trainer_config.use_wandb:
+        import wandb
+        import shutil
+
+        if wandb.run is not None:
+            wandb.finish()
+
+        # Delete local wandb logs if configured
+        wandb_config = trainer.config.trainer_config.wandb
+        should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
+            wandb_config.delete_local_logs is None
+            and wandb_config.wandb_mode != "offline"
+        )
+        if should_delete_wandb_logs:
+            wandb_dir = run_path / "wandb"
+            if wandb_dir.exists():
+                logger.info(
+                    f"Deleting local wandb logs at {wandb_dir}... "
+                    "(set trainer_config.wandb.delete_local_logs=false to disable)"
+                )
+                shutil.rmtree(wandb_dir, ignore_errors=True)
+

 def train(
     train_labels_path: Optional[List[str]] = None,
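
One detail worth noting in the cleanup block above: delete_local_logs is effectively tri-state. A small sketch of the decision logic, mirroring the condition in the diff:

    def should_delete_wandb_logs(delete_local_logs, wandb_mode: str) -> bool:
        """Mirror of the condition added in run_training above."""
        return delete_local_logs is True or (
            delete_local_logs is None and wandb_mode != "offline"
        )

    # True  -> always delete; False -> never delete;
    # None  -> delete only when the run is not offline.
    assert should_delete_wandb_logs(True, "offline") is True
    assert should_delete_wandb_logs(None, "online") is True
    assert should_delete_wandb_logs(None, "offline") is False
    assert should_delete_wandb_logs(False, "online") is False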