opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (73):
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,176 @@
1
+ import numpy as np
2
+ from evaluate import load
3
+ from sklearn.metrics import balanced_accuracy_score
4
+
5
# HuggingFace ``evaluate`` metric objects, loaded once when this module is
# imported (in the order accuracy, f1, precision, recall) and reused by
# ``compute_classification_metrics``.
accuracy_metric, f1_metric, precision_metric, recall_metric = (
    load(metric_name) for metric_name in ("accuracy", "f1", "precision", "recall")
)
10
+
11
def process_preds_labels(eval_pred, top_k=None):
    """
    Normalize a ``(logits, labels)`` pair into class-index predictions.

    Handles tuple logits (the first element is taken as the logits array),
    one-hot / multi-column labels (reduced with ``argmax`` over the last
    axis), and optionally the indices of the ``top_k`` highest-scoring
    classes per sample.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` may be array-like
            or a tuple whose first element is the logits array.
        top_k: If not None, also return the indices of the ``top_k``
            highest-scoring classes per sample. Must be >= 1.

    Returns:
        ``(preds, labels, topk_preds)``:
        ``preds`` is ``argmax`` of the logits over the last axis;
        ``labels`` holds class indices;
        ``topk_preds`` is None when ``top_k`` is None. Otherwise, for 2-D
        logits, it has shape ``(num_samples, min(top_k, num_classes))``,
        with indices in ascending score order.

    Raises:
        ValueError: If ``top_k`` is given but is less than 1.
    """
    logits, labels = eval_pred

    # Some HF models return a tuple of outputs; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]

    logits = np.asarray(logits)
    labels = np.asarray(labels)

    # Convert one-hot labels to class indices.
    if labels.ndim > 1:
        labels = np.argmax(labels, axis=-1)

    preds = np.argmax(logits, axis=-1)

    if top_k is None:
        return preds, labels, None

    # A non-positive top_k would turn the slice below into ``[:, -0:]``
    # (i.e. every class) and silently report a perfect top-k accuracy.
    if top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k!r}")

    # Indices of the top_k largest logits, in ascending score order. The
    # caller in this module only tests membership, so the order is irrelevant.
    topk_preds = np.argsort(logits, axis=-1)[:, -top_k:]
    return preds, labels, topk_preds
42
+
43
def compute_classification_metrics(eval_pred, top_k=None, mode="logits"):
    """
    Compute accuracy, balanced accuracy, macro F1/precision/recall and,
    optionally, top-k accuracy.

    Args:
        eval_pred: ``(logits, labels)`` when ``mode`` is not ``"labels"``
            (normalized via ``process_preds_labels``), or ``(preds, labels)``
            of class indices when ``mode == "labels"``.
        top_k: If not None, also report ``top_{top_k}_accuracy``. Only
            available in logits mode; hard predictions carry no ranking, so
            it is skipped when ``mode == "labels"``.
        mode: ``"labels"`` for precomputed class-index predictions; any other
            value is treated as logits.

    Returns:
        Dict mapping metric name to value, suitable as the return value of a
        HF Trainer ``compute_metrics`` callback.
    """
    metrics = {}
    if mode == "labels":
        preds, labels = eval_pred
        preds = np.array(preds)
        labels = np.array(labels)
        # Previously left unbound here, which raised UnboundLocalError when
        # top_k was passed together with mode="labels".
        topk_preds = None
    else:
        preds, labels, topk_preds = process_preds_labels(eval_pred, top_k)

    # Top-k accuracy: a sample is a hit if its label is among its top-k indices.
    if top_k is not None and topk_preds is not None:
        hits = (topk_preds == labels[:, None]).any(axis=1)
        metrics[f"top_{top_k}_accuracy"] = float(hits.mean())

    metrics["accuracy"] = accuracy_metric.compute(predictions=preds, references=labels)["accuracy"]

    metrics["balanced_accuracy"] = balanced_accuracy_score(labels, preds)

    # Macro averaging weights every class equally regardless of support.
    metrics["f1"] = f1_metric.compute(predictions=preds, references=labels, average="macro")["f1"]
    metrics["precision"] = precision_metric.compute(predictions=preds, references=labels, average="macro")["precision"]
    metrics["recall"] = recall_metric.compute(predictions=preds, references=labels, average="macro")["recall"]

    return metrics
79
+
80
def _per_class_table_lines(class_names, per_class_accuracy, per_class_f1, cm):
    """Return the fixed-width per-class table (header, separator, rows, separator) as lines."""
    separator = "-" * 65
    lines = [f"{'Class':<30} {'Accuracy':>10} {'F1':>10} {'Samples':>10}", separator]
    for i, class_name in enumerate(class_names):
        num_samples = int(cm[i].sum())
        lines.append(
            f"{class_name:<30} {per_class_accuracy[i]:>9.2f}% {per_class_f1[i]:>9.2f}% {num_samples:>10}"
        )
    lines.append(separator)
    return lines


def _save_confusion_matrix_plot(cm, class_names, set_name, plots_dir):
    """Render ``cm`` as an annotated heatmap and save it as ``confusion_matrix_{set_name}.png`` in ``plots_dir``."""
    import os

    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import seaborn as sns

    fig = plt.figure(figsize=(12, 10))
    try:
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=class_names, yticklabels=class_names)
        plt.title(f'Confusion Matrix ({set_name})')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()

        os.makedirs(plots_dir, exist_ok=True)
        plt.savefig(os.path.join(plots_dir, f'confusion_matrix_{set_name}.png'), dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing/saving fails, so repeated calls don't leak figures.
        plt.close(fig)


def compute_detailed_classification_metrics(all_logits, all_labels, class_names, save_dir, set_name):
    """
    Compute per-class metrics with classes in alphabetical order, then save and print reports.

    Args:
        all_logits: Per-sample class scores; predictions are ``argmax`` over the last axis.
        all_labels: Ground-truth class indices (keys of ``class_names``).
        class_names: Mapping from class index to class name.
        save_dir: Output root. Writes ``plots/confusion_matrix_{set_name}.png``,
            ``results/{set_name}_detailed_metrics.txt`` and ``results/{set_name}_results.tsv``.
        set_name: Split name used in titles and file names.

    Returns:
        Dict with ``balanced_accuracy`` and ``macro_f1`` (percentages), plus
        ``per_class_accuracy`` and ``per_class_f1`` dicts keyed by class name (percentages).

    Raises:
        KeyError: If a label or predicted index is not a key of ``class_names``.
    """
    import os

    from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score, f1_score

    preds = np.argmax(np.asarray(all_logits), axis=-1)

    # Re-index classes alphabetically so plots/reports have a stable, readable order.
    sorted_class_names = sorted(class_names.values())
    name_to_sorted_idx = {name: i for i, name in enumerate(sorted_class_names)}

    def to_sorted_idx(indices):
        # .tolist() yields plain Python scalars, so the dict lookup does not depend on the array scalar type.
        return np.array([name_to_sorted_idx[class_names[idx]] for idx in np.asarray(indices).tolist()])

    sorted_labels = to_sorted_idx(all_labels)
    sorted_preds = to_sorted_idx(preds)
    all_class_labels = list(range(len(sorted_class_names)))

    cm = confusion_matrix(sorted_labels, sorted_preds, labels=all_class_labels)
    # Per-class recall; max(..., 1) avoids 0/0 for classes with no true samples.
    per_class_accuracy = np.diag(cm) / np.maximum(cm.sum(axis=1), 1) * 100
    balanced_acc = balanced_accuracy_score(sorted_labels, sorted_preds) * 100
    per_class_f1 = f1_score(sorted_labels, sorted_preds, labels=all_class_labels, average=None, zero_division=0) * 100
    macro_f1 = f1_score(sorted_labels, sorted_preds, labels=all_class_labels, average="macro", zero_division=0) * 100

    _save_confusion_matrix_plot(cm, sorted_class_names, set_name, os.path.join(save_dir, 'plots'))

    results_dir = os.path.join(save_dir, 'results')
    os.makedirs(results_dir, exist_ok=True)

    table_lines = _per_class_table_lines(sorted_class_names, per_class_accuracy, per_class_f1, cm)
    separator = "-" * 65

    report_path = os.path.join(results_dir, f'{set_name}_detailed_metrics.txt')
    # Explicit UTF-8: class names may be non-ASCII and the platform default encoding may not cover them.
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(f"Balanced Accuracy: {balanced_acc:.2f}%\n")
        f.write(f"Macro F1: {macro_f1:.2f}%\n\n")
        f.write("\n".join(table_lines) + "\n\n")
        f.write("Classification Report:\n\n")
        f.write(classification_report(
            sorted_labels, sorted_preds,
            labels=all_class_labels,
            target_names=sorted_class_names,
            zero_division=0
        ))
        f.write("\n" + separator + "\n\n")
        f.write("Confusion Matrix:\n\n")
        # str(cm) elides matrices with more than 1000 cells ("..."); always write the full matrix.
        f.write(np.array2string(cm, threshold=cm.size) + "\n")

    tsv_path = os.path.join(results_dir, f'{set_name}_results.tsv')
    samples_per_class = [int(n) for n in cm.sum(axis=1)]
    with open(tsv_path, 'w', encoding='utf-8') as f:
        f.write("metric\t" + "\t".join(sorted_class_names) + "\toverall\n")
        f.write("accuracy\t" + "\t".join(f"{acc:.2f}" for acc in per_class_accuracy) + f"\t{balanced_acc:.2f}\n")
        f.write("f1\t" + "\t".join(f"{f1:.2f}" for f1 in per_class_f1) + f"\t{macro_f1:.2f}\n")
        f.write("samples\t" + "\t".join(str(n) for n in samples_per_class) + f"\t{int(cm.sum())}\n")
    print(f"Saved TSV to {tsv_path}")

    print(f"\nSaved detailed metrics to {report_path}")
    print(f"\nBalanced Accuracy: {balanced_acc:.2f}%")
    print(f"Macro F1: {macro_f1:.2f}%\n")
    for line in table_lines:
        print(line)

    return {
        "balanced_accuracy": balanced_acc,
        "macro_f1": macro_f1,
        "per_class_accuracy": {name: per_class_accuracy[i] for i, name in enumerate(sorted_class_names)},
        "per_class_f1": {name: per_class_f1[i] for i, name in enumerate(sorted_class_names)},
    }
176
+