opensportslib 0.0.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opensportslib/__init__.py +18 -0
- opensportslib/apis/__init__.py +21 -0
- opensportslib/apis/classification.py +361 -0
- opensportslib/apis/localization.py +228 -0
- opensportslib/config/classification.yaml +104 -0
- opensportslib/config/classification_tracking.yaml +103 -0
- opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
- opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
- opensportslib/config/localization.yaml +132 -0
- opensportslib/config/sngar_frames.yaml +98 -0
- opensportslib/core/__init__.py +0 -0
- opensportslib/core/loss/__init__.py +0 -0
- opensportslib/core/loss/builder.py +40 -0
- opensportslib/core/loss/calf.py +258 -0
- opensportslib/core/loss/ce.py +23 -0
- opensportslib/core/loss/combine.py +42 -0
- opensportslib/core/loss/nll.py +25 -0
- opensportslib/core/optimizer/__init__.py +0 -0
- opensportslib/core/optimizer/builder.py +38 -0
- opensportslib/core/sampler/weighted_sampler.py +104 -0
- opensportslib/core/scheduler/__init__.py +0 -0
- opensportslib/core/scheduler/builder.py +77 -0
- opensportslib/core/trainer/__init__.py +0 -0
- opensportslib/core/trainer/classification_trainer.py +1131 -0
- opensportslib/core/trainer/localization_trainer.py +1009 -0
- opensportslib/core/utils/checkpoint.py +238 -0
- opensportslib/core/utils/config.py +199 -0
- opensportslib/core/utils/data.py +85 -0
- opensportslib/core/utils/ddp.py +77 -0
- opensportslib/core/utils/default_args.py +110 -0
- opensportslib/core/utils/load_annotations.py +485 -0
- opensportslib/core/utils/seed.py +26 -0
- opensportslib/core/utils/video_processing.py +389 -0
- opensportslib/core/utils/wandb.py +110 -0
- opensportslib/datasets/__init__.py +0 -0
- opensportslib/datasets/builder.py +42 -0
- opensportslib/datasets/classification_dataset.py +582 -0
- opensportslib/datasets/localization_dataset.py +813 -0
- opensportslib/datasets/utils/__init__.py +15 -0
- opensportslib/datasets/utils/tracking.py +615 -0
- opensportslib/metrics/classification_metric.py +176 -0
- opensportslib/metrics/localization_metric.py +1482 -0
- opensportslib/models/__init__.py +0 -0
- opensportslib/models/backbones/builder.py +590 -0
- opensportslib/models/base/e2e.py +252 -0
- opensportslib/models/base/tracking.py +73 -0
- opensportslib/models/base/vars.py +29 -0
- opensportslib/models/base/video.py +130 -0
- opensportslib/models/base/video_mae.py +60 -0
- opensportslib/models/builder.py +43 -0
- opensportslib/models/heads/builder.py +266 -0
- opensportslib/models/neck/builder.py +210 -0
- opensportslib/models/utils/common.py +176 -0
- opensportslib/models/utils/impl/__init__.py +0 -0
- opensportslib/models/utils/impl/asformer.py +390 -0
- opensportslib/models/utils/impl/calf.py +74 -0
- opensportslib/models/utils/impl/gsm.py +112 -0
- opensportslib/models/utils/impl/gtad.py +347 -0
- opensportslib/models/utils/impl/tsm.py +123 -0
- opensportslib/models/utils/litebase.py +59 -0
- opensportslib/models/utils/modules.py +120 -0
- opensportslib/models/utils/shift.py +135 -0
- opensportslib/models/utils/utils.py +276 -0
- opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
- opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
- opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
- opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from evaluate import load
|
|
3
|
+
from sklearn.metrics import balanced_accuracy_score
|
|
4
|
+
|
|
5
|
+
# HuggingFace `evaluate` metric objects, created once at import time via
# `load(...)` and shared by compute_classification_metrics below.
accuracy_metric, f1_metric, precision_metric, recall_metric = (
    load(metric_name) for metric_name in ("accuracy", "f1", "precision", "recall")
)
|
|
10
|
+
|
|
11
|
+
def process_preds_labels(eval_pred, top_k=None):
    """Normalize a ``(logits, labels)`` pair into class-index predictions.

    Handles tuple logits (only the first element is used as scores),
    multi-dimensional (e.g. one-hot) labels, which are reduced to class
    indices with ``argmax``, and optionally returns top-k predictions.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` may itself be a
            tuple, in which case its first element holds the scores.
        top_k: If given, how many highest-scoring classes to return per
            sample. Must be a positive integer.

    Returns:
        ``(preds, labels, topk_preds)``:
            - ``preds``: ``argmax`` of the logits over the last axis.
            - ``labels``: class indices.
            - ``topk_preds``: for each row of the (at least 2-D) logits, the
              indices of the (at most) ``top_k`` highest-scoring classes,
              ordered from lowest to highest score. ``None`` when ``top_k``
              is ``None``.

    Raises:
        ValueError: If ``top_k`` is given and is smaller than 1.
    """
    logits, labels = eval_pred

    # Some HF models return a tuple of outputs; the scores come first.
    if isinstance(logits, tuple):
        logits = logits[0]

    logits = np.asarray(logits)
    labels = np.asarray(labels)

    # Collapse one-hot (or any multi-dimensional) labels to class indices.
    if labels.ndim > 1:
        labels = np.argmax(labels, axis=-1)

    preds = np.argmax(logits, axis=-1)

    topk_preds = None
    if top_k is not None:
        if top_k < 1:
            # A [-0:] slice would silently select *every* class.
            raise ValueError(f"top_k must be a positive integer, got {top_k}")
        # Treat a 1-D score vector as a single sample so the column slice works.
        scores_2d = np.atleast_2d(logits)
        topk_preds = np.argsort(scores_2d, axis=-1)[:, -top_k:]

    return preds, labels, topk_preds
|
|
42
|
+
|
|
43
|
+
def compute_classification_metrics(eval_pred, top_k=None, mode="logits"):
    """Compute accuracy, balanced accuracy, macro F1/precision/recall and,
    optionally, top-k accuracy.

    Args:
        eval_pred: ``(logits, labels)`` by default, or ``(preds, labels)``
            holding predicted class indices when ``mode == "labels"``.
        top_k: If given, also report ``top_{top_k}_accuracy``. Top-k needs
            scores, so it is skipped when ``mode == "labels"``.
        mode: ``"labels"`` if ``eval_pred`` already holds class indices; any
            other value means ``eval_pred`` holds logits.

    Returns:
        A dictionary mapping metric name to value, for the HF Trainer.
    """
    metrics = {}
    if mode == "labels":
        preds, labels = eval_pred
        preds = np.array(preds)
        labels = np.array(labels)
        # Only class indices are available here, so top-k cannot be derived.
        topk_preds = None
    else:
        preds, labels, topk_preds = process_preds_labels(eval_pred, top_k)

    # Top-k accuracy: a sample is a hit when its label is among its top-k classes.
    if top_k is not None and topk_preds is not None:
        hits = np.any(topk_preds == labels[:, None], axis=1)
        metrics[f"top_{top_k}_accuracy"] = int(hits.sum()) / len(labels)

    # Accuracy
    metrics["accuracy"] = accuracy_metric.compute(predictions=preds, references=labels)["accuracy"]

    # Balanced accuracy (scikit-learn)
    metrics["balanced_accuracy"] = balanced_accuracy_score(labels, preds)

    # Macro-averaged F1, precision and recall (in that order).
    for name, metric in (("f1", f1_metric), ("precision", precision_metric), ("recall", recall_metric)):
        metrics[name] = metric.compute(predictions=preds, references=labels, average="macro")[name]

    return metrics
|
|
79
|
+
|
|
80
|
+
def compute_detailed_classification_metrics(all_logits, all_labels, class_names, save_dir, set_name):
    """Compute per-class metrics for one split and save a plot and reports.

    Classes are re-indexed in alphabetical order of their names, so every
    output (confusion matrix, text report, TSV, console table) lists them
    alphabetically.

    Files written under ``save_dir`` (directories are created if missing):
        - ``plots/confusion_matrix_<set_name>.png``: annotated heatmap.
        - ``results/<set_name>_detailed_metrics.txt``: summary, per-class
          table, scikit-learn classification report and raw confusion matrix.
        - ``results/<set_name>_results.tsv``: per-class accuracy, F1 and
          sample counts, plus an ``overall`` column.

    Args:
        all_logits: Per-sample class scores. The prediction is the ``argmax``
            over the last axis.
        all_labels: Ground-truth class indices. Multi-dimensional (e.g.
            one-hot) labels are reduced with ``argmax`` over the last axis.
        class_names: Mapping from class index to class name.
        save_dir: Root directory for the outputs.
        set_name: Split name used in the plot title and output file names.

    Returns:
        Dict with ``balanced_accuracy`` and ``macro_f1`` (percentages) and
        ``per_class_accuracy`` / ``per_class_f1`` dicts keyed by class name
        (also percentages).
    """
    import os

    from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score, f1_score

    preds = np.argmax(all_logits, axis=-1)
    # Accept one-hot labels too, mirroring process_preds_labels.
    if np.ndim(all_labels) > 1:
        all_labels = np.argmax(all_labels, axis=-1)

    sorted_class_names = sorted(class_names.values())
    name_to_sorted_idx = {name: i for i, name in enumerate(sorted_class_names)}

    # Map original class indices onto their alphabetical positions.
    sorted_labels = np.array([name_to_sorted_idx[class_names[l]] for l in all_labels])
    sorted_preds = np.array([name_to_sorted_idx[class_names[p]] for p in preds])

    all_class_labels = list(range(len(sorted_class_names)))

    # Fix the matrix to every known class, in alphabetical order.
    cm = confusion_matrix(sorted_labels, sorted_preds, labels=all_class_labels)
    class_support = cm.sum(axis=1)
    # Per-class recall (%). Classes with no true samples get 0 instead of dividing by zero.
    per_class_accuracy = np.diag(cm) / np.maximum(class_support, 1) * 100
    balanced_acc = balanced_accuracy_score(sorted_labels, sorted_preds) * 100
    per_class_f1 = f1_score(sorted_labels, sorted_preds, labels=all_class_labels, average=None, zero_division=0) * 100
    macro_f1 = f1_score(sorted_labels, sorted_preds, labels=all_class_labels, average="macro", zero_division=0) * 100

    _save_confusion_matrix_plot(cm, sorted_class_names, save_dir, set_name)

    results_dir = os.path.join(save_dir, 'results')
    os.makedirs(results_dir, exist_ok=True)

    table_lines = _per_class_table_lines(sorted_class_names, per_class_accuracy, per_class_f1, cm)

    report_path = os.path.join(results_dir, f'{set_name}_detailed_metrics.txt')
    # Explicit UTF-8 so non-ASCII class names do not depend on the platform locale.
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(f"Balanced Accuracy: {balanced_acc:.2f}%\n")
        f.write(f"Macro F1: {macro_f1:.2f}%\n\n")
        f.write("\n".join(table_lines) + "\n\n")
        f.write("Classification Report:\n\n")
        f.write(classification_report(
            sorted_labels, sorted_preds,
            labels=all_class_labels,
            target_names=sorted_class_names,
            zero_division=0
        ))
        f.write("\n" + "-" * 65 + "\n\n")
        f.write("Confusion Matrix:\n\n")
        f.write(f"{cm}\n")

    tsv_path = os.path.join(results_dir, f'{set_name}_results.tsv')
    with open(tsv_path, 'w', encoding='utf-8') as f:
        f.write("metric\t" + "\t".join(sorted_class_names) + "\toverall\n")
        f.write("accuracy\t" + "\t".join(f"{v:.2f}" for v in per_class_accuracy) + f"\t{balanced_acc:.2f}\n")
        f.write("f1\t" + "\t".join(f"{v:.2f}" for v in per_class_f1) + f"\t{macro_f1:.2f}\n")
        f.write("samples\t" + "\t".join(str(int(n)) for n in class_support) + f"\t{int(cm.sum())}\n")

    print(f"Saved TSV to {tsv_path}")

    print(f"\nSaved detailed metrics to {report_path}")
    print(f"\nBalanced Accuracy: {balanced_acc:.2f}%")
    print(f"Macro F1: {macro_f1:.2f}%\n")
    for line in table_lines:
        print(line)

    return {
        "balanced_accuracy": balanced_acc,
        "macro_f1": macro_f1,
        "per_class_accuracy": dict(zip(sorted_class_names, per_class_accuracy)),
        "per_class_f1": dict(zip(sorted_class_names, per_class_f1)),
    }


def _save_confusion_matrix_plot(cm, class_labels, save_dir, set_name):
    """Save ``cm`` as an annotated heatmap to ``<save_dir>/plots/confusion_matrix_<set_name>.png``."""
    import os

    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import seaborn as sns

    fig = plt.figure(figsize=(12, 10))
    try:
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=class_labels, yticklabels=class_labels)
        plt.title(f'Confusion Matrix ({set_name})')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()

        plots_dir = os.path.join(save_dir, 'plots')
        os.makedirs(plots_dir, exist_ok=True)
        plt.savefig(os.path.join(plots_dir, f'confusion_matrix_{set_name}.png'), dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def _per_class_table_lines(class_labels, per_class_accuracy, per_class_f1, cm):
    """Return the fixed-width per-class table (header, rule, one row per class, rule) as lines."""
    rule = "-" * 65
    lines = [f"{'Class':<30} {'Accuracy':>10} {'F1':>10} {'Samples':>10}", rule]
    for i, class_name in enumerate(class_labels):
        num_samples = int(cm[i].sum())
        lines.append(f"{class_name:<30} {per_class_accuracy[i]:>9.2f}% {per_class_f1[i]:>9.2f}% {num_samples:>10}")
    lines.append(rule)
    return lines
|
|
176
|
+
|