sleap-nn 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/cli.py +36 -0
- sleap_nn/config/trainer_config.py +18 -0
- sleap_nn/evaluation.py +81 -22
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/predict.py +29 -0
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +324 -8
- sleap_nn/training/lightning_modules.py +542 -32
- sleap_nn/training/model_trainer.py +48 -57
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +13 -2
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +37 -16
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
"""Inference-level postprocessing filters for pose predictions.
|
|
2
|
+
|
|
3
|
+
This module provides filters that run after model inference but before tracking.
|
|
4
|
+
These filters are independent of tracking configuration and can be used standalone.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import List, Literal
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import sleap_io as sio
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def filter_overlapping_instances(
    labels: sio.Labels,
    threshold: float = 0.8,
    method: Literal["iou", "oks"] = "iou",
) -> sio.Labels:
    """Filter overlapping instances using greedy non-maximum suppression.

    Removes duplicate/overlapping predicted instances by applying greedy NMS
    based on either bounding box IOU or Object Keypoint Similarity (OKS). When
    two predicted instances overlap above the threshold, the lower-scoring one
    is removed. Non-predicted (user) instances are never removed.

    This filter runs independently of tracking and can be used to clean up
    model outputs before saving or further processing.

    Args:
        labels: Labels object with predicted instances to filter.
        threshold: Similarity threshold for considering instances as overlapping.
            Instances with similarity > threshold are candidates for removal.
            Lower values are more aggressive (remove more).
            Typical values: 0.3 (aggressive) to 0.8 (permissive).
        method: Similarity metric to use for comparing instances.
            "iou": Bounding box intersection-over-union.
            "oks": Object Keypoint Similarity (pose-based).

    Returns:
        The input Labels object with overlapping instances removed.
        Modification is done in place, but the object is also returned
        for convenience. Surviving instances keep their original relative
        order within each frame; frames where nothing is suppressed are left
        untouched.

    Raises:
        ValueError: If `method` is not "iou" or "oks". This is checked before
            any frame is processed.

    Example:
        >>> # Filter instances with >80% bounding box overlap
        >>> labels = filter_overlapping_instances(labels, threshold=0.8, method="iou")
        >>> # Filter using OKS similarity
        >>> labels = filter_overlapping_instances(labels, threshold=0.5, method="oks")

    Note:
        - Only affects frames with 2+ predicted instances
        - Uses instance.score for ranking; higher scores are preferred
        - For IOU: bounding boxes computed from non-NaN keypoints
        - For OKS: simplified OKS (uniform per-keypoint constant, scale taken
          from the bounding-box area of the higher-scoring instance)
    """
    # Validate up front so an invalid method is reported even when no frame
    # contains enough predicted instances to trigger filtering.
    if method not in ("iou", "oks"):
        raise ValueError(f"Unknown method: {method}. Use 'iou' or 'oks'.")

    for lf in labels.labeled_frames:
        instances = list(lf.instances)
        if len(instances) <= 1:
            continue

        # Positions (within the frame) of predicted instances; only these are
        # candidates for suppression.
        pred_positions = [
            pos
            for pos, inst in enumerate(instances)
            if isinstance(inst, sio.PredictedInstance)
        ]
        if len(pred_positions) <= 1:
            continue

        predicted = [instances[pos] for pos in pred_positions]

        # dtype=float turns a missing (None) score into NaN instead of building
        # an object array that cannot be argsorted.
        scores = np.array(
            [_instance_score(inst) for inst in predicted], dtype=float
        )

        if method == "iou":
            bboxes = np.array([_instance_bbox(inst) for inst in predicted])
            keep_indices = _nms_greedy_iou(bboxes, scores, threshold)
        else:  # method == "oks" (validated above)
            points = [inst.numpy() for inst in predicted]
            keep_indices = _nms_greedy_oks(points, scores, threshold)

        kept = set(keep_indices)
        if len(kept) == len(predicted):
            # Nothing suppressed: leave the frame's instance list as-is.
            continue

        # Drop suppressed predictions while preserving the original order of
        # everything that remains (user instances included).
        dropped = {pos for k, pos in enumerate(pred_positions) if k not in kept}
        lf.instances = [
            inst for pos, inst in enumerate(instances) if pos not in dropped
        ]

    return labels
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _instance_bbox(instance: sio.PredictedInstance) -> np.ndarray:
    """Return the axis-aligned box enclosing an instance's non-NaN keypoints.

    Args:
        instance: Instance exposing `numpy()`, expected to yield an
            (n_nodes, 2) array of x/y coordinates with NaN for missing nodes.

    Returns:
        Array `[xmin, ymin, xmax, ymax]`, or `[0, 0, 0, 0]` when every
        keypoint row contains a NaN.
    """
    coords = instance.numpy()  # (n_nodes, 2)
    # Drop any node row with a NaN in either coordinate.
    present = coords[~np.isnan(coords).any(axis=1)]

    if present.shape[0] == 0:
        return np.array([0.0, 0.0, 0.0, 0.0])

    xs = present[:, 0]
    ys = present[:, 1]
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()])
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _instance_score(instance: sio.PredictedInstance) -> float:
    """Return an instance's confidence score used for NMS ranking.

    Args:
        instance: Predicted instance.

    Returns:
        The instance's `score` attribute, or 1.0 when the object has no
        `score` attribute at all.
    """
    try:
        return instance.score
    except AttributeError:
        # Objects without a score are ranked as fully confident.
        return 1.0
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _nms_greedy_iou(
    bboxes: np.ndarray,
    scores: np.ndarray,
    threshold: float,
) -> List[int]:
    """Apply greedy NMS using bounding box IOU.

    Instances are visited from highest to lowest score. Each visited instance
    is kept, and every not-yet-visited instance whose IOU with it exceeds
    `threshold` is suppressed.

    Args:
        bboxes: Bounding boxes of shape (N, 4) as [xmin, ymin, xmax, ymax].
        scores: Confidence scores of shape (N,). NaN scores rank below every
            non-NaN score.
        threshold: IOU threshold for suppression.

    Returns:
        List of indices to keep, in order of decreasing score. Equal scores
        are visited in input order (lower index first).
    """
    if len(bboxes) == 0:
        return []

    # A stable argsort of the negated scores gives a descending order that keeps
    # ties in input order. Negating NaN leaves NaN, which NumPy sorts last, so a
    # missing score cannot outrank a real one. (A reversed ascending argsort
    # would put NaN first.)
    order = np.argsort(-np.asarray(scores, dtype=float), kind="stable")

    keep: List[int] = []
    while order.size:
        best, rest = int(order[0]), order[1:]
        keep.append(best)

        if rest.size == 0:
            break

        similarities = _compute_iou_one_to_many(bboxes[best], bboxes[rest])

        # Survivors are the candidates not overlapping the kept box above the
        # threshold (NaN similarities compare False and are dropped).
        order = rest[similarities <= threshold]

    return keep
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _nms_greedy_oks(
    points_list: List[np.ndarray],
    scores: np.ndarray,
    threshold: float,
) -> List[int]:
    """Apply greedy NMS using Object Keypoint Similarity (OKS).

    Instances are visited from highest to lowest score. Each visited instance
    is kept, and every not-yet-visited instance whose OKS with it exceeds
    `threshold` is suppressed.

    Args:
        points_list: List of keypoint arrays, each of shape (n_nodes, 2).
        scores: Confidence scores of shape (N,). NaN scores rank below every
            non-NaN score.
        threshold: OKS threshold for suppression.

    Returns:
        List of indices to keep, in order of decreasing score. Equal scores
        are visited in input order (lower index first).
    """
    if len(points_list) == 0:
        return []

    # A stable argsort of the negated scores gives a descending order that keeps
    # ties in input order, and NaN (unchanged by negation) sorts last.
    order = np.argsort(-np.asarray(scores, dtype=float), kind="stable")

    keep: List[int] = []
    while order.size:
        best, rest = int(order[0]), order[1:]
        keep.append(best)

        if rest.size == 0:
            break

        # OKS is asymmetric: the kept (higher-scoring) instance is passed
        # first, so it serves as the reference for the scale term.
        similarities = np.array(
            [_compute_oks(points_list[best], points_list[j]) for j in rest]
        )

        order = rest[similarities <= threshold]

    return keep
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _compute_iou_one_to_many(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Compute IOU between one box and multiple boxes.

    Args:
        box: Single box of shape (4,) as [xmin, ymin, xmax, ymax].
        boxes: Boxes of shape (N, 4) in the same format. A single (4,) box is
            also accepted and treated as N=1.

    Returns:
        Float IOU values of shape (N,). Pairs whose union area is zero (both
        boxes have zero area) get an IOU of 0.0.
    """
    box = np.asarray(box, dtype=float)
    boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)

    # Intersection rectangle, clamped to zero size when the boxes are disjoint.
    inter_w = np.maximum(
        0.0, np.minimum(box[2], boxes[:, 2]) - np.maximum(box[0], boxes[:, 0])
    )
    inter_h = np.maximum(
        0.0, np.minimum(box[3], boxes[:, 3]) - np.maximum(box[1], boxes[:, 1])
    )
    inter_area = inter_w * inter_h

    # Individual and union areas.
    area_a = (box[2] - box[0]) * (box[3] - box[1])
    area_b = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    union_area = area_a + area_b - inter_area

    # Divide only where the union is positive. np.where(...) would still
    # evaluate 0/0 for degenerate pairs and emit RuntimeWarnings.
    iou = np.zeros_like(union_area)
    np.divide(inter_area, union_area, out=iou, where=union_area > 0)
    return iou
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _compute_oks(
    points_a: np.ndarray,
    points_b: np.ndarray,
    kappa: float = 0.1,
) -> float:
    """Compute a simplified Object Keypoint Similarity (OKS) between two poses.

    Every keypoint gets the same falloff constant `kappa`, and the squared
    scale is the bounding-box area of instance A's non-NaN keypoints. The
    result is the mean of `exp(-d^2 / (2 * area * kappa^2))` over keypoints
    that are non-NaN in both instances. Because the scale comes from A only,
    the measure is not symmetric in its arguments.

    Args:
        points_a: Reference keypoints, shape (n_nodes, 2).
        points_b: Keypoints to compare against, shape (n_nodes, 2).
        kappa: Per-keypoint falloff constant. Default 0.1.

    Returns:
        OKS in [0, 1]; higher means more similar. Returns 0.0 when no keypoint
        is non-NaN in both instances, when A has fewer than two non-NaN
        keypoints, or when A's keypoint bounding box has zero area (e.g. all
        points collinear along an axis).
    """
    present_a = ~np.isnan(points_a).any(axis=1)
    shared = present_a & ~np.isnan(points_b).any(axis=1)
    if not shared.any():
        return 0.0

    reference = points_a[present_a]
    if len(reference) < 2:
        return 0.0

    xs, ys = reference[:, 0], reference[:, 1]
    area = (xs.max() - xs.min()) * (ys.max() - ys.min())
    if area <= 0:
        return 0.0

    offsets = points_a[shared] - points_b[shared]
    sq_dist = (offsets**2).sum(axis=1)
    similarity = np.exp(-sq_dist / (2 * area * kappa**2))
    return float(similarity.mean())
|
sleap_nn/predict.py
CHANGED
|
@@ -74,6 +74,9 @@ def run_inference(
|
|
|
74
74
|
frames: Optional[list] = None,
|
|
75
75
|
crop_size: Optional[int] = None,
|
|
76
76
|
peak_threshold: Union[float, List[float]] = 0.2,
|
|
77
|
+
filter_overlapping: bool = False,
|
|
78
|
+
filter_overlapping_method: str = "iou",
|
|
79
|
+
filter_overlapping_threshold: float = 0.8,
|
|
77
80
|
integral_refinement: Optional[str] = "integral",
|
|
78
81
|
integral_patch_size: int = 5,
|
|
79
82
|
return_confmaps: bool = False,
|
|
@@ -160,6 +163,15 @@ def run_inference(
|
|
|
160
163
|
centroid and centered-instance model, where the first element corresponds
|
|
161
164
|
to centroid model peak finding threshold and the second element is for
|
|
162
165
|
centered-instance model peak finding.
|
|
166
|
+
filter_overlapping: (bool) If True, removes overlapping instances after
|
|
167
|
+
inference using greedy NMS. Applied independently of tracking.
|
|
168
|
+
Default: False.
|
|
169
|
+
filter_overlapping_method: (str) Similarity metric for filtering overlapping
|
|
170
|
+
instances. One of "iou" (bounding box) or "oks" (keypoint similarity).
|
|
171
|
+
Default: "iou".
|
|
172
|
+
filter_overlapping_threshold: (float) Similarity threshold for filtering.
|
|
173
|
+
Instances with similarity > threshold are removed (keeping higher-scoring).
|
|
174
|
+
Typical values: 0.3 (aggressive) to 0.8 (permissive). Default: 0.8.
|
|
163
175
|
integral_refinement: (str) If `None`, returns the grid-aligned peaks with no refinement.
|
|
164
176
|
If `"integral"`, peaks will be refined with integral regression.
|
|
165
177
|
Default: `"integral"`.
|
|
@@ -553,6 +565,20 @@ def run_inference(
|
|
|
553
565
|
make_labels=make_labels,
|
|
554
566
|
)
|
|
555
567
|
|
|
568
|
+
# Filter overlapping instances (independent of tracking)
|
|
569
|
+
if filter_overlapping and make_labels:
|
|
570
|
+
from sleap_nn.inference.postprocessing import filter_overlapping_instances
|
|
571
|
+
|
|
572
|
+
output = filter_overlapping_instances(
|
|
573
|
+
output,
|
|
574
|
+
threshold=filter_overlapping_threshold,
|
|
575
|
+
method=filter_overlapping_method,
|
|
576
|
+
)
|
|
577
|
+
logger.info(
|
|
578
|
+
f"Filtered overlapping instances with {filter_overlapping_method.upper()} "
|
|
579
|
+
f"threshold: {filter_overlapping_threshold}"
|
|
580
|
+
)
|
|
581
|
+
|
|
556
582
|
if tracking:
|
|
557
583
|
lfs = [x for x in output]
|
|
558
584
|
if tracking_clean_instance_count > 0:
|
|
@@ -607,6 +633,9 @@ def run_inference(
|
|
|
607
633
|
# Build inference parameters for provenance
|
|
608
634
|
inference_params = {
|
|
609
635
|
"peak_threshold": peak_threshold,
|
|
636
|
+
"filter_overlapping": filter_overlapping,
|
|
637
|
+
"filter_overlapping_method": filter_overlapping_method,
|
|
638
|
+
"filter_overlapping_threshold": filter_overlapping_threshold,
|
|
610
639
|
"integral_refinement": integral_refinement,
|
|
611
640
|
"integral_patch_size": integral_patch_size,
|
|
612
641
|
"batch_size": batch_size,
|
sleap_nn/train.py
CHANGED
|
@@ -118,6 +118,70 @@ def run_training(
|
|
|
118
118
|
logger.info(f"p90 dist: {metrics['distance_metrics']['p90']}")
|
|
119
119
|
logger.info(f"p50 dist: {metrics['distance_metrics']['p50']}")
|
|
120
120
|
|
|
121
|
+
# Log test metrics to wandb summary
|
|
122
|
+
if (
|
|
123
|
+
d_name.startswith("test")
|
|
124
|
+
and trainer.config.trainer_config.use_wandb
|
|
125
|
+
):
|
|
126
|
+
import wandb
|
|
127
|
+
|
|
128
|
+
if wandb.run is not None:
|
|
129
|
+
summary_metrics = {
|
|
130
|
+
f"eval/{d_name}/mOKS": metrics["mOKS"]["mOKS"],
|
|
131
|
+
f"eval/{d_name}/oks_voc_mAP": metrics["voc_metrics"][
|
|
132
|
+
"oks_voc.mAP"
|
|
133
|
+
],
|
|
134
|
+
f"eval/{d_name}/oks_voc_mAR": metrics["voc_metrics"][
|
|
135
|
+
"oks_voc.mAR"
|
|
136
|
+
],
|
|
137
|
+
f"eval/{d_name}/mPCK": metrics["pck_metrics"]["mPCK"],
|
|
138
|
+
f"eval/{d_name}/PCK_5": metrics["pck_metrics"]["PCK@5"],
|
|
139
|
+
f"eval/{d_name}/PCK_10": metrics["pck_metrics"]["PCK@10"],
|
|
140
|
+
f"eval/{d_name}/distance_avg": metrics["distance_metrics"][
|
|
141
|
+
"avg"
|
|
142
|
+
],
|
|
143
|
+
f"eval/{d_name}/distance_p50": metrics["distance_metrics"][
|
|
144
|
+
"p50"
|
|
145
|
+
],
|
|
146
|
+
f"eval/{d_name}/distance_p95": metrics["distance_metrics"][
|
|
147
|
+
"p95"
|
|
148
|
+
],
|
|
149
|
+
f"eval/{d_name}/distance_p99": metrics["distance_metrics"][
|
|
150
|
+
"p99"
|
|
151
|
+
],
|
|
152
|
+
f"eval/{d_name}/visibility_precision": metrics[
|
|
153
|
+
"visibility_metrics"
|
|
154
|
+
]["precision"],
|
|
155
|
+
f"eval/{d_name}/visibility_recall": metrics[
|
|
156
|
+
"visibility_metrics"
|
|
157
|
+
]["recall"],
|
|
158
|
+
}
|
|
159
|
+
for key, value in summary_metrics.items():
|
|
160
|
+
wandb.run.summary[key] = value
|
|
161
|
+
|
|
162
|
+
# Finish wandb run and cleanup after all evaluation is complete
|
|
163
|
+
if trainer.config.trainer_config.use_wandb:
|
|
164
|
+
import wandb
|
|
165
|
+
import shutil
|
|
166
|
+
|
|
167
|
+
if wandb.run is not None:
|
|
168
|
+
wandb.finish()
|
|
169
|
+
|
|
170
|
+
# Delete local wandb logs if configured
|
|
171
|
+
wandb_config = trainer.config.trainer_config.wandb
|
|
172
|
+
should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
|
|
173
|
+
wandb_config.delete_local_logs is None
|
|
174
|
+
and wandb_config.wandb_mode != "offline"
|
|
175
|
+
)
|
|
176
|
+
if should_delete_wandb_logs:
|
|
177
|
+
wandb_dir = run_path / "wandb"
|
|
178
|
+
if wandb_dir.exists():
|
|
179
|
+
logger.info(
|
|
180
|
+
f"Deleting local wandb logs at {wandb_dir}... "
|
|
181
|
+
"(set trainer_config.wandb.delete_local_logs=false to disable)"
|
|
182
|
+
)
|
|
183
|
+
shutil.rmtree(wandb_dir, ignore_errors=True)
|
|
184
|
+
|
|
121
185
|
|
|
122
186
|
def train(
|
|
123
187
|
train_labels_path: Optional[List[str]] = None,
|