wunderscout 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wunderscout/__init__.py +6 -0
- wunderscout/core.py +164 -0
- wunderscout/exporters.py +43 -0
- wunderscout/geometry.py +74 -0
- wunderscout/heatmap.py +115 -0
- wunderscout/main.py +598 -0
- wunderscout/pass_network.py +103 -0
- wunderscout/teams.py +76 -0
- wunderscout/vision.py +155 -0
- wunderscout-0.1.2.dist-info/METADATA +27 -0
- wunderscout-0.1.2.dist-info/RECORD +12 -0
- wunderscout-0.1.2.dist-info/WHEEL +4 -0
wunderscout/__init__.py
ADDED
wunderscout/core.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
import cv2
|
|
2
|
+
import supervision as sv
|
|
3
|
+
import numpy as np
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from .vision import VisionEngine
|
|
6
|
+
from .geometry import PitchMapper
|
|
7
|
+
from .teams import TeamClassifier
|
|
8
|
+
from .exporters import DataExporter
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ScoutingPipeline:
    """End-to-end scouting pipeline.

    Detects players/ball, tracks them, classifies teams, projects feet
    positions onto a normalized pitch via a homography, writes an
    annotated video, and finally exports per-frame positions to CSV.
    """

    def __init__(self, player_weights, field_weights):
        # Detection models (players/ball + field keypoints) and embedder.
        self.engine = VisionEngine(player_weights, field_weights)
        # Estimates the image -> normalized-pitch homography.
        self.mapper = PitchMapper()
        # Assigns tracker IDs to one of the two teams via crop embeddings.
        self.classifier = TeamClassifier()

    def run(self, video_path, output_video_path):
        """Process `video_path`; write annotated video and CSVs derived from
        `output_video_path` (CSV names replace the .mp4 suffix)."""
        # 1. Warm-up (Calibration): fit the team classifier on a sample of
        # player crops collected across the video before the main pass.
        print("WORKER: Calibrating teams...")
        crops = self.engine.get_calibration_crops(video_path)
        if len(crops) > 0:
            embeddings = self.engine.get_embeddings(crops)
            self.classifier.fit(embeddings)
        else:
            print("WARNING: No player crops found for calibration.")

        # 2. Setup Video I/O
        output_path_obj = Path(output_video_path)
        output_path_obj.parent.mkdir(parents=True, exist_ok=True)
        # ---------------------------------------------------------------------

        # The capture is only used to read fps/size for the writer; frames
        # themselves come from sv.get_video_frames_generator below.
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        out = cv2.VideoWriter(
            output_video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
        )

        if not out.isOpened():
            print(f"ERROR: Could not create video file at {output_video_path}")
            return

        tracker = sv.ByteTrack()
        # frame_idx -> {"players": {tracker_id: (x, y)}, "ball": (x, y) | None}
        tracking_results = {}

        # ID Constants (class ids emitted by the player detector)
        BALL_ID = 0
        GOALKEEPER_ID = 1
        PLAYER_ID = 2
        REFEREE_ID = 3

        # 3. Main Processing Loop
        print(f"WORKER: Starting processing: {video_path}")
        frame_generator = sv.get_video_frames_generator(video_path)

        for frame_idx, frame in enumerate(frame_generator):
            if frame_idx % 100 == 0:
                print(f"WORKER: Processing frame {frame_idx}")

            # --- A. DETECTION ---
            all_dets = self.engine.detect_players(frame)
            f_res = self.engine.detect_field(frame)

            # --- B. FIELD HOMOGRAPHY ---
            # Re-estimate when keypoints are available; otherwise fall back
            # to the last successful homography so mapping degrades gracefully.
            H = None
            if f_res.keypoints is not None and len(f_res.keypoints.xy) > 0:
                H = self.mapper.get_matrix(
                    f_res.keypoints.xy[0].cpu().numpy(),
                    f_res.keypoints.conf[0].cpu().numpy(),
                )
            else:
                H = self.mapper.last_h

            # --- C. SEPARATE BALL & OTHERS ---
            ball_detections = all_dets[all_dets.class_id == BALL_ID]
            # Pad the (tiny) ball box so downstream annotation stays visible.
            ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)

            other_detections = all_dets[all_dets.class_id != BALL_ID]
            other_detections = other_detections.with_nms(threshold=0.5)

            # --- D. TRACKING ---
            tracked_objects = tracker.update_with_detections(other_detections)

            # Split tracked objects
            tracked_players = tracked_objects[tracked_objects.class_id == PLAYER_ID]
            tracked_gks = tracked_objects[tracked_objects.class_id == GOALKEEPER_ID]
            tracked_refs = tracked_objects[tracked_objects.class_id == REFEREE_ID]

            # --- E. TEAM CLASSIFICATION ---

            # 1. Players: embed each crop, take the per-track consensus team.
            if len(tracked_players) > 0:
                p_crops = [sv.crop_image(frame, xyxy) for xyxy in tracked_players.xyxy]
                p_pil = [sv.cv2_to_pillow(c) for c in p_crops]
                p_embeddings = self.engine.get_embeddings(p_pil)

                final_team_ids = []
                for i, tid in enumerate(tracked_players.tracker_id):
                    team_id = self.classifier.get_consensus_team(tid, p_embeddings[i])
                    final_team_ids.append(team_id)

                # class_id is repurposed from here on: it now holds team ids.
                tracked_players.class_id = np.array(final_team_ids)

            # 2. Goalkeepers: resolved against the classified player positions.
            if len(tracked_gks) > 0 and len(tracked_players) > 0:
                tracked_gks.class_id = self.classifier.resolve_goalkeepers_team_id(
                    tracked_players, tracked_gks
                )

            # 3. Referees (Shift ID 3 -> 2)
            if len(tracked_refs) > 0:
                tracked_refs.class_id -= 1

            # --- F. DATA STORAGE ---
            tracking_results[frame_idx] = {"players": {}, "ball": None}
            data_targets = sv.Detections.merge([tracked_players, tracked_gks])

            # Positions are only recorded when a homography is available;
            # otherwise the frame keeps its empty defaults.
            if H is not None:
                if len(data_targets) > 0:
                    feet_coords = data_targets.get_anchors_coordinates(
                        sv.Position.BOTTOM_CENTER
                    )
                    transformed_feet = self.mapper.transform(feet_coords, H)

                    for i, tid in enumerate(data_targets.tracker_id):
                        px, py = transformed_feet[i]
                        # Clamp: the projection can land slightly outside the
                        # normalized [0, 1] pitch.
                        tracking_results[frame_idx]["players"][tid] = (
                            max(0.0, min(1.0, px)),
                            max(0.0, min(1.0, py)),
                        )

                if len(ball_detections) > 0:
                    ball_coords = ball_detections.get_anchors_coordinates(
                        sv.Position.CENTER
                    )
                    # Only the first ball detection is projected/stored.
                    transformed_ball = self.mapper.transform([ball_coords[0]], H)
                    bx, by = transformed_ball[0]
                    tracking_results[frame_idx]["ball"] = (
                        max(0.0, min(1.0, bx)),
                        max(0.0, min(1.0, by)),
                    )

            # --- G. DRAW & WRITE VIDEO ---
            all_tracked = sv.Detections.merge(
                [tracked_players, tracked_gks, tracked_refs]
            )
            annotated_frame = self.engine.draw_annotations(
                frame, all_tracked, ball_detections
            )
            out.write(annotated_frame)

        # 4. Cleanup
        out.release()
        cap.release()
        print(f"WORKER: Video saved to {output_video_path}")

        # Save CSVs
        # NOTE(review): `frame_idx` is the last loop variable — this raises
        # NameError if the video yielded zero frames; confirm inputs are
        # always non-empty. It is also the last *index*, not a count.
        final_assignments = self.classifier.get_final_assignments()
        csv_path = output_video_path.replace(".mp4", ".csv")
        DataExporter.save_csvs(
            tracking_results, final_assignments, frame_idx, fps, csv_path
        )
|
wunderscout/exporters.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import csv
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class DataExporter:
    """Serializes tracking results into Metrica-style wide CSV files."""

    @staticmethod
    def save_csvs(tracking_data, team_assignments, total_frames, fps, output_path):
        """Write `<base>_Home.csv` and `<base>_Away.csv` next to `output_path`.

        tracking_data: {frame_idx: {"ball": (x,y), "players": {id: (x,y)}}}
        team_assignments: {tracker_id: team_id}  (team 0 = Home, 1 = Away)
        """
        path_obj = Path(output_path)
        path_obj.parent.mkdir(parents=True, exist_ok=True)
        base_name = str(path_obj.with_suffix(""))

        home_ids = [tid for tid, team in team_assignments.items() if team == 0]
        away_ids = [tid for tid, team in team_assignments.items() if team == 1]

        missing = ("NaN", "NaN")
        pad_left = ["", "", ""]
        pad_right = ["", ""]

        def _dump_team(filename, team_name, ids):
            # Three header rows (team label / tracker id / column name),
            # each padded to align with the Period/Frame/Time and Ball columns.
            with open(filename, "w", newline="") as handle:
                out = csv.writer(handle)

                out.writerow(pad_left + [team_name] * (2 * len(ids)) + pad_right)

                id_row = []
                for pid in ids:
                    id_row += [str(pid), str(pid)]
                out.writerow(pad_left + id_row + pad_right)

                coord_cols = []
                for pid in ids:
                    coord_cols += [f"Player{pid}_X", f"Player{pid}_Y"]
                out.writerow(
                    ["Period", "Frame", "Time [s]"] + coord_cols + ["Ball_X", "Ball_Y"]
                )

                # One data row per frame; frames with no data get NaN pairs.
                for f_idx in range(total_frames):
                    snapshot = tracking_data.get(f_idx, {"ball": None, "players": {}})
                    row = [1, f_idx, f"{f_idx / fps:.2f}"]
                    players = snapshot["players"]
                    for pid in ids:
                        row.extend(players.get(pid, missing))
                    ball = snapshot["ball"]
                    row.extend(ball if ball else missing)
                    out.writerow(row)

        _dump_team(f"{base_name}_Home.csv", "Home", sorted(home_ids))
        _dump_team(f"{base_name}_Away.csv", "Away", sorted(away_ids))
|
wunderscout/geometry.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import cv2
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
# Canonical keypoint positions on a normalized pitch: keypoint index ->
# (x, y), where x runs along the pitch length and y along the width,
# both in [0, 1]. Indices are assumed to match the field-detection
# model's keypoint ordering — TODO confirm against the model card.
PITCH_CONFIG = {
    # --- LEFT GOAL LINE ---
    0: (0.000, 0.000),  # Top-Left Corner
    1: (0.000, 0.204),  # Top Edge of Penalty Box
    2: (0.000, 0.365),  # Top Edge of Goal Area
    3: (0.000, 0.635),  # Bottom Edge of Goal Area
    4: (0.000, 0.796),  # Bottom Edge of Penalty Box
    5: (0.000, 1.000),  # Bottom-Left Corner
    # --- LEFT PENALTY AREA ---
    6: (0.052, 0.365),
    7: (0.052, 0.635),
    8: (0.105, 0.500),  # Penalty Spot (Left)
    9: (0.157, 0.204),
    10: (0.157, 0.392),
    11: (0.157, 0.608),
    12: (0.157, 0.796),
    # --- MIDFIELD ---
    13: (0.413, 0.500),
    14: (0.500, 0.000),
    15: (0.500, 0.365),
    16: (0.500, 0.635),
    17: (0.500, 1.000),
    18: (0.587, 0.500),
    # --- RIGHT PENALTY AREA --- (mirror of the left side)
    19: (0.843, 0.204),
    20: (0.843, 0.392),
    21: (0.843, 0.608),
    22: (0.843, 0.796),
    23: (0.895, 0.500),  # Penalty Spot (Right)
    24: (0.948, 0.365),
    25: (0.948, 0.635),
    # --- RIGHT GOAL LINE ---
    26: (1.000, 0.000),
    27: (1.000, 0.204),
    28: (1.000, 0.365),
    29: (1.000, 0.635),
    30: (1.000, 0.796),
    31: (1.000, 1.000),
}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class PitchMapper:
    """Maps image-space coordinates onto the normalized pitch plane."""

    def __init__(self, pitch_config=PITCH_CONFIG):
        # keypoint index -> normalized pitch (x, y)
        self.pitch_config = pitch_config
        # Most recently estimated homography; reused when estimation fails.
        self.last_h = None

    def get_matrix(self, keypoints_xy, keypoints_conf):
        """Estimate the image->pitch homography from detected keypoints.

        Falls back to (and returns) the cached matrix when fewer than four
        confident keypoints are available.
        """
        image_pts = []
        pitch_pts = []
        for idx, (point, score) in enumerate(zip(keypoints_xy, keypoints_conf)):
            if score > 0.5 and idx in self.pitch_config:
                image_pts.append(point)
                pitch_pts.append(self.pitch_config[idx])

        # findHomography needs at least 4 point correspondences.
        if len(image_pts) >= 4:
            matrix, _ = cv2.findHomography(
                np.array(image_pts), np.array(pitch_pts), cv2.RANSAC
            )
            self.last_h = matrix

        return self.last_h

    def transform(self, points, H=None):
        """Project Nx2 image points through H (or the cached homography).

        Returns an (N, 2) array, or an empty list when no matrix/points.
        """
        matrix = self.last_h if H is None else H
        if matrix is None or len(points) == 0:
            return []

        reshaped = np.array(points).reshape(-1, 1, 2).astype(np.float32)
        projected = cv2.perspectiveTransform(reshaped, matrix)
        return projected.reshape(-1, 2)
|
wunderscout/heatmap.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""Generate heatmap artifacts (scatter / histogram / KDE) for one player.

Reads a Metrica-style wide tracking CSV, extracts Player17's positions,
and writes PNG plots plus JSON exports (consumed by a three.js frontend)
into ./heatmap/.
"""

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import gaussian_kde

# BUG FIX: the output directory was never created, so every savefig/open
# below failed on a fresh checkout. Create it up front.
Path("./heatmap").mkdir(parents=True, exist_ok=True)

# === Load the raw tracking data CSV ===
# NOTE: "header=2" → skip the first two rows (team/labels) and use row 3 as header
df = pd.read_csv(
    "./data/Sample_Game_1_RawTrackingData_Away_Team.csv",
    header=2,
)

# === Clean column names so each player has _X and _Y ===
# Each "PlayerN"/"Ball" header spans two raw columns (X then unnamed Y);
# expand them into explicit <name>_X / <name>_Y pairs.
cleaned_columns = []
colnames = df.columns.tolist()
i = 0
while i < len(colnames):
    col = colnames[i]
    if col.startswith("Player") or col.startswith("Ball"):
        cleaned_columns.append(f"{col}_X")
        cleaned_columns.append(f"{col}_Y")
        i += 2
    else:
        cleaned_columns.append(col)
        i += 1
df.columns = cleaned_columns

print("Columns cleaned. First few rows:")
print(df.head())

# === Extract Player17 (drop NaN values where tracking failed) ===
player17 = df[["Player17_X", "Player17_Y"]].dropna()
x = player17["Player17_X"].to_numpy()
y = player17["Player17_Y"].to_numpy()

# === Detect scale (normalized [0,1] or real meters) ===
# 1.5 (not 1.0) tolerates slight overshoot in normalized coordinates.
if x.max() <= 1.5 and y.max() <= 1.5:
    print("Scaling Player17 data from normalized [0,1] to meters...")
    x = x * 105  # pitch length in meters
    y = y * 68  # pitch width in meters
else:
    print("Data appears to already be in meters, leaving as is.")

print("First 10 points:", list(zip(x[:10], y[:10])))

# =============================================================================
# 1. Scatter Plot (sanity check, raw positions)
# =============================================================================
fig, ax = plt.subplots(figsize=(10, 7))
# Pitch outline
ax.plot([0, 105, 105, 0, 0], [0, 0, 68, 68, 0], color="black")
ax.plot([52.5, 52.5], [0, 68], color="black")  # halfway line
# Player positions
ax.scatter(x, y, s=1, alpha=0.3, color="blue")
ax.set_xlim(0, 105)
ax.set_ylim(0, 68)
ax.set_title("Player17 Movement Scatter (raw positions)")
plt.savefig("./heatmap/player17_scatter.png", dpi=150, bbox_inches="tight")

# =============================================================================
# 2. Histogram Heatmap (occupancy grid)
# =============================================================================
# 50x34 bins ≈ 2.1m x 2m cells over a 105m x 68m pitch.
heatmap, xedges, yedges = np.histogram2d(x, y, bins=(50, 34), range=[[0, 105], [0, 68]])

fig, ax = plt.subplots(figsize=(10, 7))
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
im = ax.imshow(
    heatmap.T, origin="lower", extent=extent, cmap="Blues", alpha=0.7, aspect="auto"
)
ax.plot([0, 105, 105, 0, 0], [0, 0, 68, 68, 0], color="black")
ax.plot([52.5, 52.5], [0, 68], color="black")
fig.colorbar(im, ax=ax, label="Frames")
ax.set_title("Player17 Heatmap (Histogram)")
plt.savefig("./heatmap/player17_histogram.png", dpi=150, bbox_inches="tight")

# === Export histogram data as JSON for three.js ===
heatmap_data = {
    "xedges": xedges.tolist(),
    "yedges": yedges.tolist(),
    "values": heatmap.T.tolist(),  # transpose so rows correspond to y-axis correctly
}
with open("./heatmap/player17_histogram.json", "w") as f:
    json.dump(heatmap_data, f)

# =============================================================================
# 3. KDE Heatmap (smoothed density field)
# =============================================================================
values = np.vstack([x, y])
kde = gaussian_kde(values)

# Define mesh grid (100 x-samples, 68 y-samples) and evaluate the density.
X, Y = np.meshgrid(np.linspace(0, 105, 100), np.linspace(0, 68, 68))
Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)

fig, ax = plt.subplots(figsize=(10, 7))
sns.kdeplot(x=x, y=y, fill=True, cmap="Blues", alpha=0.7, thresh=0.05, levels=50, ax=ax)
ax.plot([0, 105, 105, 0, 0], [0, 0, 68, 68, 0], color="black")
ax.plot([52.5, 52.5], [0, 68], color="black")
ax.set_xlim(0, 105)
ax.set_ylim(0, 68)
ax.set_title("Player17 Heatmap (KDE Smoothed)")
plt.savefig("./heatmap/player17_kde.png", dpi=150, bbox_inches="tight")

# === Export KDE density field for three.js ===
kde_data = {
    "x": X[0].tolist(),  # x grid coordinates
    "y": Y[:, 0].tolist(),  # y grid coordinates
    "values": Z.tolist(),  # density values
}
with open("./heatmap/player17_kde.json", "w") as f:
    json.dump(kde_data, f)

print("Outputs saved: scatter, histogram PNG+JSON, KDE PNG+JSON for Player17")
|
wunderscout/main.py
ADDED
|
@@ -0,0 +1,598 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from dotenv import load_dotenv
|
|
3
|
+
from roboflow import Roboflow
|
|
4
|
+
import torch
|
|
5
|
+
from ultralytics import YOLO
|
|
6
|
+
import supervision as sv
|
|
7
|
+
import cv2
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
from transformers import AutoProcessor, SiglipVisionModel
|
|
10
|
+
from more_itertools import chunked
|
|
11
|
+
import numpy as np
|
|
12
|
+
import umap
|
|
13
|
+
from sklearn.cluster import KMeans
|
|
14
|
+
import plotly.graph_objects as go
|
|
15
|
+
import base64
|
|
16
|
+
from io import BytesIO
|
|
17
|
+
|
|
18
|
+
# Pull ROBOFLOW_API_KEY (and any other settings) from a local .env file.
load_dotenv()

# SigLIP backbone used to embed player crops for team clustering.
SIGLIP_MODEL_PATH = "google/siglip-base-patch16-224"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): the embedding model is downloaded/loaded at import time,
# which makes importing this module slow and heavy — confirm intentional.
EMBEDDINGS_MODEL = SiglipVisionModel.from_pretrained(SIGLIP_MODEL_PATH).to(DEVICE)
EMBEDDINGS_PROCESSOR = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH)
ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY")

# Roboflow model id for pitch keypoint detection (not referenced in this
# module — presumably consumed elsewhere in the package; verify).
FIELD_DETECTION_MODEL_ID = "football-field-detection-f07vi/14"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def train():
    """Download the football-players dataset from Roboflow and fine-tune YOLO11s.

    Returns the ultralytics training results object (previously computed
    but discarded).
    """
    print("Beginning training ...")
    # This is to get the roboflow dataset and train.
    rf = Roboflow(api_key=ROBOFLOW_API_KEY)

    project = rf.workspace("roboflow-jvuqo").project("football-players-detection-3zvbc")
    version = project.version(20)
    dataset = version.download("yolov11")

    model_base_s = YOLO("yolo11s.pt")

    # BUG FIX: train on the dataset we actually downloaded. The data path
    # was previously hard-coded to "./football-players-detection-20/data.yaml",
    # silently diverging from `dataset.location` when Roboflow changes the
    # download directory.
    results = model_base_s.train(
        data=f"{dataset.location}/data.yaml",
        save=True,
        epochs=50,
        imgsz=1280,
        plots=True,
        device=0,
        batch=6,
        project="runs/detect",
    )
    return results
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def inference(video_path):
    """Run the trained detector on the first frame of `video_path` and save
    an annotated still image to ./runs/annotated_frame.jpg.

    Ball gets a padded triangle marker; all other classes get ellipses.
    """
    print("Inference 3 ...")
    # We do another inference with different annotations
    BALL_ID = 0

    model_trained = YOLO("./runs/detect/train/weights/best.pt")
    # FIX: the third palette entry was "FFD700" (missing '#'), inconsistent
    # with the identical "#00BFFF"/"#FF1493"/"#FFD700" palettes used by the
    # other inference functions in this file.
    ellipse_annotator = sv.EllipseAnnotator(
        color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]), thickness=2
    )

    triangle_annotator = sv.TriangleAnnotator(
        color=sv.Color.from_hex("#FFD700"), base=25, height=21, outline_thickness=1
    )

    frame_generator_2 = sv.get_video_frames_generator(video_path)
    frame_2 = next(frame_generator_2)

    result_3 = model_trained.predict(
        frame_2, save=True, project="runs", name="inference"
    )[0]
    detections_2 = sv.Detections.from_ultralytics(result_3)

    # Pad the (small) ball box so the triangle marker is clearly visible.
    ball_detections = detections_2[detections_2.class_id == BALL_ID]
    ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)

    all_detections = detections_2[detections_2.class_id != BALL_ID]
    all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
    # Shift remaining ids down by one so they index the 3-color palette.
    all_detections.class_id -= 1

    annotated_frame_2 = frame_2.copy()
    annotated_frame_2 = ellipse_annotator.annotate(
        scene=annotated_frame_2, detections=all_detections
    )
    annotated_frame_2 = triangle_annotator.annotate(
        scene=annotated_frame_2, detections=ball_detections
    )

    cv2.imwrite("./runs/annotated_frame.jpg", annotated_frame_2)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def inference_with_player_tracking(video_path):
    """Run detection + ByteTrack over the whole video and write an annotated
    copy (ellipses, tracker-id labels, ball triangle) to ./runs/video_tracked.mp4.
    """
    print("Inference with tracking ...")
    # We will do inference again but this time with the video.
    BALL_ID = 0

    model_trained = YOLO("./runs/detect/train/weights/best.pt")
    ellipse_annotator_2 = sv.EllipseAnnotator(
        color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]), thickness=2
    )

    label_annotator_2 = sv.LabelAnnotator(
        color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
        text_color=sv.Color.from_hex("#000000"),
        text_position=sv.Position.BOTTOM_CENTER,
    )

    triangle_annotator_2 = sv.TriangleAnnotator(
        color=sv.Color.from_hex("#FFD700"),
        base=25,
        height=21,
        outline_thickness=1,
    )

    tracker = sv.ByteTrack()

    frame_generator_3 = sv.get_video_frames_generator(video_path)

    # The capture is only needed to read output geometry/fps; release it
    # immediately (BUG FIX: it was previously never released).
    cap = cv2.VideoCapture(video_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(
        "./runs/video_tracked.mp4", fourcc, fps, (frame_width, frame_height)
    )

    for frame in frame_generator_3:
        result_tracking = model_trained.predict(frame, conf=0.3)[0]

        detections = sv.Detections.from_ultralytics(result_tracking)

        # Pad the ball box so the triangle marker stays visible.
        ball_detections = detections[detections.class_id == BALL_ID]
        ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)

        all_detections = detections[detections.class_id != BALL_ID]
        all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
        # Shift ids down by one so they index the 3-color palette.
        all_detections.class_id -= 1
        all_detections = tracker.update_with_detections(detections=all_detections)

        labels = [f"#{tracker_id}" for tracker_id in all_detections.tracker_id]

        annotated_frame = frame.copy()
        annotated_frame = ellipse_annotator_2.annotate(
            scene=annotated_frame, detections=all_detections
        )
        annotated_frame = label_annotator_2.annotate(
            scene=annotated_frame, detections=all_detections, labels=labels
        )
        annotated_frame = triangle_annotator_2.annotate(
            scene=annotated_frame, detections=ball_detections
        )

        out.write(annotated_frame)
    out.release()
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def create_player_crops(video_path):
    """Collect player crops from every STRIDE-th frame of `video_path`.

    Returns a list of PIL images, suitable for the SigLIP embedding
    extractor. (Removed a stale commented-out block that saved crops to
    disk — dead code.)
    """
    print("Generating player crops ...")
    # Getting training data for cluster model
    PLAYER_ID = 2
    STRIDE = 30

    model_trained = YOLO("./runs/detect/train/weights/best.pt")
    frame_generator_4 = sv.get_video_frames_generator(
        source_path=video_path, stride=STRIDE
    )

    crops = []
    for frame in tqdm(frame_generator_4, desc="collecting crops"):
        result = model_trained.predict(frame, conf=0.3)[0]
        detections = sv.Detections.from_ultralytics(result)
        detections = detections.with_nms(threshold=0.5, class_agnostic=True)
        detections = detections[detections.class_id == PLAYER_ID]
        players_crops = [sv.crop_image(frame, xyxy) for xyxy in detections.xyxy]
        crops += players_crops

    print(f"Total crops collected: {len(crops)}.")

    # Convert to PIL
    pil_crops = [sv.cv2_to_pillow(c) for c in crops]
    return pil_crops
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def extract_embeddings(crops):
    """Embed PIL player crops with the module-level SigLIP model.

    Processes `crops` in batches of 32 and mean-pools token embeddings
    into one vector per crop. Returns an (N, D) numpy array.
    """
    BATCH_SIZE = 32

    vectors = []
    with torch.no_grad():
        for batch in tqdm(chunked(crops, BATCH_SIZE), desc="embedding extraction"):
            inputs = EMBEDDINGS_PROCESSOR(images=batch, return_tensors="pt").to(DEVICE)
            outputs = EMBEDDINGS_MODEL(**inputs)
            # Mean over the token dimension -> one vector per crop.
            pooled = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
            vectors.append(pooled)

    return np.concatenate(vectors)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def cluster_players_by_team(embeddings):
    """Split player embeddings into two team clusters.

    Reduces to 3 dimensions with UMAP, then runs K-Means (k=2, fixed
    random_state for reproducibility). Returns (projections, labels).
    """
    reducer = umap.UMAP(n_components=3)
    kmeans = KMeans(n_clusters=2, n_init=10, random_state=42)

    projections = reducer.fit_transform(embeddings)
    labels = kmeans.fit_predict(projections)

    return projections, labels
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def save_projection_plot_html(projections, clusters, crops):
    """Write an interactive 3D scatter plot to ./runs/player_clusters.html.

    Each point is one player crop; clicking a point displays the crop in a
    fixed side panel. `projections` is the (N, 3) UMAP output, `clusters`
    the (N,) cluster labels, `crops` PIL images aligned with the rows.
    """
    # inline helper: convert a crop (PIL image) to base64 string
    def pil_image_to_data_uri(image):
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{img_str}"

    # convert crops to base64-encoded URIs
    image_data_uris = {
        f"image_{i}": pil_image_to_data_uri(crop) for i, crop in enumerate(crops)
    }
    image_ids = np.array([f"image_{i}" for i in range(len(crops))])

    # One trace per cluster so the legend can toggle each team.
    traces = []
    unique_clusters = np.unique(clusters)
    for unique_cluster in unique_clusters:
        mask = clusters == unique_cluster
        trace = go.Scatter3d(
            x=projections[mask][:, 0],
            y=projections[mask][:, 1],
            z=projections[mask][:, 2],
            mode="markers+text",  # hard-coded: markers+text
            text=clusters[mask],
            customdata=image_ids[mask],
            name=str(unique_cluster),
            marker=dict(size=6),
            hovertemplate="<b>class: %{text}</b><br>image ID: %{customdata}<extra></extra>",
        )
        traces.append(trace)

    # make cube axis range (same range on all axes + 5% padding)
    min_val = np.min(projections)
    max_val = np.max(projections)
    padding = (max_val - min_val) * 0.05
    axis_range = [min_val - padding, max_val + padding]

    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(
            xaxis=dict(title="X", range=axis_range),
            yaxis=dict(title="Y", range=axis_range),
            zaxis=dict(title="Z", range=axis_range),
            aspectmode="cube",
        ),
        width=1000,
        height=1000,
        showlegend=True,
    )

    # embed chart HTML with custom JS for crop preview
    plotly_div = fig.to_html(
        full_html=False, include_plotlyjs=True, div_id="scatter-plot-3d"
    )
    # NOTE(review): `{image_data_uris}` interpolates the Python dict repr
    # straight into JavaScript. Single-quoted keys happen to be valid JS
    # object syntax, but json.dumps(...) would be more robust — confirm
    # before changing, since this string is emitted verbatim.
    javascript_code = f"""
    <script>
    function displayImage(imageId) {{
        var imageElement = document.getElementById('image-display');
        var placeholderText = document.getElementById('placeholder-text');
        var imageDataURIs = {image_data_uris};
        imageElement.src = imageDataURIs[imageId];
        imageElement.style.display = 'block';
        placeholderText.style.display = 'none';
    }}

    var chartElement = document.getElementById('scatter-plot-3d');
    chartElement.on('plotly_click', function(data) {{
        var customdata = data.points[0].customdata;
        displayImage(customdata);
    }});
    </script>
    """

    html_template = f"""
    <!DOCTYPE html>
    <html>
    <head>
    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
    <style>
    #image-container {{
        position: fixed;
        top: 0;
        left: 0;
        width: 200px;
        height: 200px;
        padding: 5px;
        border: 1px solid #ccc;
        background-color: white;
        z-index: 1000;
        box-sizing: border-box;
        display: flex;
        align-items: center;
        justify-content: center;
        text-align: center;
    }}
    #image-display {{
        width: 100%;
        height: 100%;
        object-fit: contain;
    }}
    </style>
    </head>
    <body>
    {plotly_div}
    <div id="image-container">
    <img id="image-display" src="" alt="Selected image" style="display: none;" />
    <p id="placeholder-text">Click a data point to display the image</p>
    </div>
    {javascript_code}
    </body>
    </html>
    """

    out_path = "./runs/player_clusters.html"
    with open(out_path, "w") as f:
        f.write(html_template)

    print(f"Interactive plot saved to {out_path}. Open it in your browser.")
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def resolve_goalkeepers_team_id(players, goalkeepers):
    """Assign each goalkeeper to the team whose player centroid is nearest.

    `players` must already carry team ids (0/1) in class_id. Returns a
    numpy array of team ids, one per goalkeeper.
    """
    players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
    goalkeepers_xy = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)

    # Mean foot position of each team on the image plane.
    team_0_centroid = players_xy[players.class_id == 0].mean(axis=0)
    team_1_centroid = players_xy[players.class_id == 1].mean(axis=0)

    assignments = []
    for gk_xy in goalkeepers_xy:
        nearer_team_0 = np.linalg.norm(gk_xy - team_0_centroid) < np.linalg.norm(
            gk_xy - team_1_centroid
        )
        assignments.append(0 if nearer_team_0 else 1)

    return np.array(assignments)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def inference_with_goalkeepers(video_path):
    """Detect, track, and team-classify everyone in *video_path* and write an
    annotated video to ./runs/video_tracked_2.mp4.

    Players are split into two teams by KMeans on UMAP-reduced appearance
    embeddings; goalkeepers inherit the team whose player centroid is closest;
    referees are drawn with the third palette color.
    """
    BALL_ID = 0
    GOALKEEPER_ID = 1
    PLAYER_ID = 2
    REFEREE_ID = 3

    # Calibrate the team classifier on player crops sampled from the video.
    crops = create_player_crops(video_path)
    embeddings = extract_embeddings(crops)

    reducer = umap.UMAP(n_components=3)
    kmeans = KMeans(n_clusters=2, n_init=10, random_state=42)
    clustering_model = kmeans.fit(reducer.fit_transform(embeddings))

    model_trained = YOLO("./runs/detect/train/weights/best.pt")

    # fix: the original consumed (and therefore skipped) the first frame with a
    # stray next() before the processing loop.
    frame_generator = sv.get_video_frames_generator(video_path)

    tracker = sv.ByteTrack()

    # Probe the stream geometry for the writer, then release the capture
    # (fix: it was previously never released).
    cap = cv2.VideoCapture(video_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(
        "./runs/video_tracked_2.mp4", fourcc, fps, (frame_width, frame_height)
    )

    # Annotators are loop-invariant; build them once (was rebuilt every frame).
    ellipse_annotator = sv.EllipseAnnotator(
        color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
        thickness=2,
    )
    label_annotator = sv.LabelAnnotator(
        color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
        text_color=sv.Color.from_hex("#000000"),
        text_position=sv.Position.BOTTOM_CENTER,
    )
    triangle_annotator = sv.TriangleAnnotator(
        color=sv.Color.from_hex("#FFD700"), base=25, height=21, outline_thickness=1
    )

    for frame in frame_generator:
        result = model_trained.predict(frame, conf=0.3)[0]
        detections = sv.Detections.from_ultralytics(result)

        # Pad the ball box slightly so the triangle marker clears the ball.
        ball_detections = detections[detections.class_id == BALL_ID]
        ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)

        # Everything else goes through class-agnostic NMS and the tracker.
        all_detections = detections[detections.class_id != BALL_ID]
        all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
        all_detections = tracker.update_with_detections(detections=all_detections)

        goalkeepers_detections = all_detections[
            all_detections.class_id == GOALKEEPER_ID
        ]
        players_detections = all_detections[all_detections.class_id == PLAYER_ID]
        referees_detections = all_detections[all_detections.class_id == REFEREE_ID]

        # Re-label players by team (0/1) from their appearance embeddings.
        players_crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
        player_embeddings = extract_embeddings(players_crops)
        player_projection = reducer.transform(player_embeddings)
        players_detections.class_id = clustering_model.predict(player_projection)

        goalkeepers_detections.class_id = resolve_goalkeepers_team_id(
            players_detections, goalkeepers_detections
        )

        # Referee class id 3 maps onto palette slot 2 (yellow).
        referees_detections.class_id -= 1

        all_detections = sv.Detections.merge(
            [players_detections, goalkeepers_detections, referees_detections]
        )

        labels = [f"#{tracker_id}" for tracker_id in all_detections.tracker_id]

        # Integer class ids are required for palette color lookup.
        all_detections.class_id = all_detections.class_id.astype(int)

        annotated_frame = frame.copy()
        annotated_frame = ellipse_annotator.annotate(
            scene=annotated_frame, detections=all_detections
        )
        annotated_frame = label_annotator.annotate(
            scene=annotated_frame, detections=all_detections, labels=labels
        )
        annotated_frame = triangle_annotator.annotate(
            scene=annotated_frame, detections=ball_detections
        )

        out.write(annotated_frame)
    out.release()
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def keypoint_detection_train():
    """Download the football-field keypoint dataset from Roboflow and fine-tune
    a YOLO11 pose model on it; artifacts land under runs/keypoint_detect/."""
    print("Beginning training ...")
    # Fetch the dataset release used for pitch-keypoint training.
    rf = Roboflow(api_key=ROBOFLOW_API_KEY)
    field_project = rf.workspace("roboflow-jvuqo").project(
        "football-field-detection-f07vi"
    )
    field_project.version(15).download("yolov8")

    # Fine-tune the medium pose checkpoint; batch kept small for VRAM.
    base_model = YOLO("yolo11m-pose.pt")
    base_model.train(
        data="./football-field-detection-15/data.yaml",
        save=True,
        epochs=300,
        plots=True,
        imgsz=1080,
        device=0,
        batch=2,
        project="runs/keypoint_detect/",
    )
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
def keypoint_detection_inference(video_path: str) -> None:
    """Detect pitch keypoints on the first frame of *video_path* and save an
    annotated image to ./runs/annotated_frame_keypoint.jpg.

    Falls back to writing the raw frame when nothing is detected.
    """
    model = YOLO("./runs/keypoint_detect/train4/weights/best.pt")
    vertex_annotator = sv.VertexAnnotator(
        color=sv.Color.from_hex("#FF1493"),
        radius=8,
    )

    frame = next(sv.get_video_frames_generator(video_path))
    res = model.predict(frame, conf=0.3, iou=0.7, imgsz=1088, verbose=False)[0]

    # fix: also bail out when the result carries no keypoints — the original
    # only checked res.boxes and would crash on res.keypoints below.
    if res.boxes is None or len(res.boxes) == 0 or res.keypoints is None:
        cv2.imwrite("./runs/annotated_frame_keypoint.jpg", frame)
        return

    # Keep only the highest-confidence pitch detection.
    best_i = int(np.argmax(res.boxes.conf.cpu().numpy()))

    kxy = res.keypoints.xy[best_i].cpu().numpy()
    kconf = res.keypoints.conf[best_i].cpu().numpy()

    # Drop low-confidence keypoints before drawing.
    kxy = kxy[kconf > 0.5]

    kp = sv.KeyPoints(xy=kxy[np.newaxis, ...])

    out = vertex_annotator.annotate(scene=frame.copy(), key_points=kp)
    cv2.imwrite("./runs/annotated_frame_keypoint.jpg", out)
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def keypoint_detection_indices(video_path, output_path="keypoint_debug.mp4"):
    """Render a debug video that marks every confident pitch keypoint with a red
    dot and its 1-based index, so keypoint ordering can be verified visually.

    Frames where the model returns no keypoints are stamped "NO PITCH DETECTED".
    """
    print(f"DEBUG: Starting keypoint visualization for {video_path}")

    model_trained = YOLO("./runs/keypoint_detect/train4/weights/best.pt")

    # Mirror the input video's geometry and frame rate in the output writer.
    cap = cv2.VideoCapture(video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    frame_count = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        frame_count += 1
        # Progress heartbeat roughly once per second of 30fps footage.
        if frame_count % 30 == 0:
            print(f"Processing frame {frame_count}...")

        result = model_trained.predict(frame, conf=0.5, verbose=False)[0]

        if result.keypoints is not None and result.keypoints.xy.shape[0] > 0:
            # Only the first detected pitch instance is visualized.
            kpts_xy = result.keypoints.xy[0].cpu().numpy()
            kpts_conf = result.keypoints.conf[0].cpu().numpy()

            for i, (x, y) in enumerate(kpts_xy):
                conf = kpts_conf[i]

                # Skip low-confidence points and the (0, 0) padding the model
                # emits for absent keypoints.
                if conf < 0.5 or (x == 0 and y == 0):
                    continue

                x, y = int(x), int(y)
                # Red dot at the keypoint, yellow 1-based index next to it.
                cv2.circle(frame, (x, y), 5, (0, 0, 255), -1)
                cv2.putText(
                    frame,
                    str(i + 1),
                    (x + 10, y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.6,
                    (0, 255, 255),
                    2,
                )
        else:
            # No pitch found on this frame — stamp a red warning instead.
            cv2.putText(
                frame,
                "NO PITCH DETECTED",
                (50, 50),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (0, 0, 255),
                2,
            )
        out.write(frame)

    cap.release()
    out.release()
    print("Video generation complete.")
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def keypoint_detection_val_test(model_path, data_yaml):
    """Evaluate a trained keypoint model on the held-out test split and print
    the resulting metrics (plots are saved by the validator)."""
    trained = YOLO(model_path)
    metrics = trained.val(data=data_yaml, split="test", plots=True)

    print(metrics)
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
def main():
    """Entry point: report CUDA availability, then run the selected pipeline stage."""
    cuda_ok = torch.cuda.is_available()
    print("CUDA available:", cuda_ok)
    if cuda_ok:
        print("GPU name:", torch.cuda.get_device_name(0))

    video_path = "/home/lucas/Documents/dev/local/yolo_torch/video.mp4"

    # Other pipeline stages — uncomment the one you want to run:
    # train()
    # inference(video_path)
    # inference_with_player_tracking(video_path)
    # crops = create_player_crops(video_path)
    # embeddings = extract_embeddings(crops)
    # projections, clusters = cluster_players_by_team(embeddings)
    # save_projection_plot_html(projections, clusters, crops)
    # inference_with_goalkeepers(video_path)
    # keypoint_detection_train()
    # keypoint_detection_inference(video_path)
    keypoint_detection_indices(video_path)
    # val_test(
    #     "./runs/keypoint_detect/train4/weights/best.pt",
    #     "./football-field-detection-15/data.yaml",
    # )
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
# Script entry point.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import networkx as nx
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
import os
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
# Load match events JSON (replace with your actual file path).
with open(
    "./data/3825818.json",
    "r",
) as f:
    events = json.load(f)

# Map player id -> player name from the "Starting XI" lineup events.
player_id_to_name = {
    lineup["player"]["id"]: lineup["player"]["name"]
    for ev in events
    if ev["type"]["name"] == "Starting XI"
    for lineup in ev["tactics"]["lineup"]
}
|
|
24
|
+
|
|
25
|
+
# Accumulate pass counts and touch locations for the selected team.
edges = defaultdict(int)  # (passer_id, recipient_id) -> completed-pass count
player_positions = defaultdict(list)  # player_id -> list of [x, y] touch points

TEAM_NAME = "Real Sociedad"

# Extract completed passes: a missing "outcome" key is treated as a
# completed pass (hence the {"name": "Complete"} default).
for ev in events:
    if ev["type"]["name"] != "Pass" or ev["team"]["name"] != TEAM_NAME:
        continue
    passer = ev["player"]["id"]
    pass_data = ev.get("pass", {})
    recipient = pass_data.get("recipient", {}).get("id")
    outcome = pass_data.get("outcome", {"name": "Complete"})["name"]

    if outcome == "Complete" and recipient is not None:
        edges[(passer, recipient)] += 1
        player_positions[passer].append(ev["location"])
        player_positions[recipient].append(ev["pass"]["end_location"])

# Collapse each player's touch locations into a single average position.
avg_positions = {
    pid: [
        sum(pt[0] for pt in touches) / len(touches),
        sum(pt[1] for pt in touches) / len(touches),
    ]
    for pid, touches in player_positions.items()
}
|
|
51
|
+
|
|
52
|
+
# Export a JSON-friendly node/link structure for downstream visualization.
nodes = [
    {"id": pid, "x": xy[0], "y": xy[1]} for pid, xy in avg_positions.items()
]
links = [
    {"source": passer, "target": receiver, "value": n}
    for (passer, receiver), n in edges.items()
]

network = {"nodes": nodes, "links": links}
os.makedirs("pass_network", exist_ok=True)
with open("./pass_network/pass_network.json", "w") as f:
    json.dump(network, f, indent=2)
|
|
63
|
+
|
|
64
|
+
# Build a directed pass graph with average positions as layout coordinates.
G = nx.DiGraph()

for pid, xy in avg_positions.items():
    G.add_node(pid, pos=(xy[0], xy[1]))

for (passer, receiver), n in edges.items():
    G.add_edge(passer, receiver, weight=n)

# Layout and label lookups for drawing; unknown ids fall back to the raw id.
pos = nx.get_node_attributes(G, "pos")
labels = {pid: player_id_to_name.get(pid, str(pid)) for pid in G.nodes()}

fig, ax = plt.subplots(figsize=(10, 7))

# Pitch outline on a 120 x 80 coordinate system (matches the event locations).
ax.set_xlim(0, 120)
ax.set_ylim(0, 80)
ax.plot([0, 120, 120, 0, 0], [0, 0, 80, 80, 0], color="black")

# Nodes, then edges scaled by pass volume, then player-name labels.
nx.draw_networkx_nodes(G, pos, ax=ax, node_color="skyblue", node_size=500)
nx.draw_networkx_edges(
    G,
    pos,
    ax=ax,
    width=[d["weight"] * 0.2 for _, _, d in G.edges(data=True)],
    alpha=0.7,
    arrowsize=10,
)
nx.draw_networkx_labels(G, pos, labels=labels, ax=ax, font_size=8)

plt.title("Team Pass Network")
plt.savefig("./pass_network/pass_network_viz.png", dpi=150, bbox_inches="tight")
|
wunderscout/teams.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import umap
|
|
3
|
+
from sklearn.cluster import KMeans
|
|
4
|
+
import supervision as sv
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TeamClassifier:
    """Clusters player appearance embeddings into two teams (0/1) and keeps a
    per-tracker voting history so assignments stabilize over time."""

    def __init__(self):
        self.reducer = umap.UMAP(n_components=3)
        self.clusterer = KMeans(n_clusters=2, n_init=10, random_state=42)
        # tracker_id -> list of recent per-frame team votes (0 or 1).
        self.history = {}

    def fit(self, embeddings):
        """Calibrate the reducer and clusterer on a batch of embeddings."""
        self.clusterer.fit(self.reducer.fit_transform(embeddings))

    def get_consensus_team(self, tracker_id, embedding):
        """Classify one embedding, record the vote, and return the majority team."""
        projected = self.reducer.transform(embedding.reshape(1, -1))
        vote = self.clusterer.predict(projected)[0]

        votes = self.history.setdefault(tracker_id, [])
        votes.append(vote)
        # Keep a sliding window of the 50 most recent votes.
        if len(votes) > 50:
            votes.pop(0)

        return 1 if (sum(votes) / len(votes)) > 0.5 else 0

    def resolve_goalkeepers_team_id(self, players, goalkeepers):
        """
        Assigns goalkeepers to the team whose centroid is closest.
        players: sv.Detections (already classified with class_id 0 or 1)
        goalkeepers: sv.Detections
        """
        if len(players) == 0 or len(goalkeepers) == 0:
            return np.array([0] * len(goalkeepers))

        players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
        gk_anchors = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)

        team_0_mask = players.class_id == 0
        team_1_mask = players.class_id == 1

        # A missing team gets a sentinel centroid: team 0 at the origin,
        # team 1 far away, so goalkeepers fall to whichever team is visible.
        centroid_0 = (
            players_xy[team_0_mask].mean(axis=0)
            if np.any(team_0_mask)
            else np.array([0, 0])
        )
        centroid_1 = (
            players_xy[team_1_mask].mean(axis=0)
            if np.any(team_1_mask)
            else np.array([10000, 10000])
        )

        return np.array(
            [
                0
                if np.linalg.norm(gk - centroid_0) < np.linalg.norm(gk - centroid_1)
                else 1
                for gk in gk_anchors
            ]
        )

    def get_final_assignments(self):
        """Return {tracker_id: team} from majority vote; empty histories are skipped."""
        return {
            tid: (1 if (sum(votes) / len(votes)) > 0.5 else 0)
            for tid, votes in self.history.items()
            if len(votes) > 0
        }
|
|
76
|
+
|
wunderscout/vision.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from ultralytics import YOLO
|
|
3
|
+
import supervision as sv
|
|
4
|
+
from transformers import AutoProcessor, SiglipVisionModel, data
|
|
5
|
+
from roboflow import Roboflow
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
from more_itertools import chunked
|
|
8
|
+
import numpy as np
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class VisionEngine:
    """Bundles the two YOLO detectors, the SigLIP embedding backbone, and the
    supervision annotators used throughout the scouting pipeline."""

    def __init__(self, player_weights, field_weights, device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.player_model = YOLO(player_weights)
        self.field_model = YOLO(field_weights)

        # SigLIP backbone used to embed player crops for team clustering.
        siglip_path = "google/siglip-base-patch16-224"
        self.siglip_model = SiglipVisionModel.from_pretrained(siglip_path).to(
            self.device
        )
        self.siglip_processor = AutoProcessor.from_pretrained(siglip_path)

        # --- Annotators ---
        # Palette slots: 0=Blue team, 1=Pink team, 2=Yellow (Referee).
        self.palette = sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"])
        self.ellipse_annotator = sv.EllipseAnnotator(color=self.palette, thickness=2)
        self.label_annotator = sv.LabelAnnotator(
            color=self.palette,
            text_color=sv.Color.from_hex("#000000"),
            text_position=sv.Position.BOTTOM_CENTER,
        )
        self.triangle_annotator = sv.TriangleAnnotator(
            color=sv.Color.from_hex("#FFD700"), base=25, height=21, outline_thickness=1
        )

    def get_calibration_crops(self, video_path, stride=30):
        """Sample every *stride*-th frame and return PIL crops of detected players."""
        PLAYER_ID = 2
        sampled_frames = sv.get_video_frames_generator(
            source_path=video_path, stride=stride
        )

        crops = []
        for sampled in sampled_frames:
            detections = self.detect_players(sampled)
            # Only player boxes feed the team-color calibration.
            player_boxes = detections[detections.class_id == PLAYER_ID].xyxy
            for box in player_boxes:
                crops.append(sv.cv2_to_pillow(sv.crop_image(sampled, box)))

        print(f"VisionEngine: Collected {len(crops)} calibration crops.")
        return crops

    def get_embeddings(self, pil_crops, batch_size=32):
        """Embed PIL crops with SigLIP in batches; returns an (N, D) array
        (empty array when no crops are given)."""
        collected = []

        with torch.no_grad():
            for batch in chunked(pil_crops, batch_size):
                inputs = self.siglip_processor(images=batch, return_tensors="pt").to(
                    self.device
                )
                outputs = self.siglip_model(**inputs)
                # Mean-pool token states into one vector per crop.
                pooled = torch.mean(outputs.last_hidden_state, dim=1).cpu().numpy()
                collected.append(pooled)

        return np.concatenate(collected) if collected else np.array([])

    def detect_players(self, frame, conf=0.3):
        """Run the player/ball detector on one frame; returns sv.Detections."""
        prediction = self.player_model.predict(frame, conf=conf, verbose=False)[0]
        return sv.Detections.from_ultralytics(prediction)

    def detect_field(self, frame, conf=0.3):
        """Run the pitch-keypoint model on one frame; returns the raw result."""
        return self.field_model.predict(frame, conf=conf, verbose=False)[0]

    def draw_annotations(self, frame, all_detections, ball_detections):
        """Return a copy of *frame* with the ball marker, ellipses, and tracker
        labels drawn on it."""
        annotated = frame.copy()

        # 1. Ball first, so the people annotations layer on top.
        annotated = self.triangle_annotator.annotate(
            scene=annotated, detections=ball_detections
        )

        # 2. People (players, goalkeepers, referees).
        if len(all_detections) > 0:
            # Integer class ids are required for palette color lookup.
            all_detections.class_id = all_detections.class_id.astype(int)
            labels = [f"#{tid}" for tid in all_detections.tracker_id]

            annotated = self.ellipse_annotator.annotate(
                scene=annotated, detections=all_detections
            )
            annotated = self.label_annotator.annotate(
                scene=annotated, detections=all_detections, labels=labels
            )

        return annotated
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class ScoutingTrainer:
    """Downloads datasets from Roboflow and fine-tunes the two YOLO models
    (player/ball detector and pitch-keypoint pose model)."""

    def __init__(self, api_key):
        self.rf = Roboflow(api_key=api_key)

    def train_players(
        self,
        workspace,
        project,
        version,
        epochs=300,
        output_dir="../runs/training/player",
    ):
        """Fine-tune the player/ball detector on a Roboflow detection dataset.

        workspace/project/version identify the Roboflow dataset release.
        Returns the ultralytics training results object.
        """
        rf_project = self.rf.workspace(workspace).project(project)
        dataset = rf_project.version(version).download("yolov11")
        model = YOLO("../data/base_models/yolo11m.pt")

        return model.train(
            data=f"{dataset.location}/data.yaml",
            epochs=epochs,
            imgsz=1280,
            plots=True,
            device=0,
            batch=2,  # small batch to fit consumer-GPU VRAM
            project=output_dir,
        )

    def train_field(
        self,
        workspace,
        project,
        version,
        epochs=300,
        output_dir="../runs/training/field",
    ):
        """Fine-tune the pitch-keypoint (pose) model on a Roboflow dataset.

        Returns the ultralytics training results object.
        """
        rf_project = self.rf.workspace(workspace).project(project)
        # fix: the requested *version* parameter was ignored — version 15 was
        # hard-coded, so callers silently trained on the wrong release.
        dataset = rf_project.version(version).download(
            "yolov8", location="../data/data_sets/"
        )
        model = YOLO("yolo11m-pose.pt")

        return model.train(
            data=f"{dataset.location}/data.yaml",
            save=True,
            epochs=epochs,
            plots=True,
            imgsz=1080,
            device=0,
            batch=2,  # small batch to fit consumer-GPU VRAM
            project=output_dir,
        )
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: wunderscout
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Requires-Dist: ipython>=9.5.0
|
|
7
|
+
Requires-Dist: matplotlib>=3.10.6
|
|
8
|
+
Requires-Dist: memory-profiler>=0.61.0
|
|
9
|
+
Requires-Dist: more-itertools>=10.8.0
|
|
10
|
+
Requires-Dist: networkx>=3.5
|
|
11
|
+
Requires-Dist: numba>=0.58
|
|
12
|
+
Requires-Dist: numpy>=2.3.2
|
|
13
|
+
Requires-Dist: opencv-python>=4.11.0.86
|
|
14
|
+
Requires-Dist: pandas>=2.3.2
|
|
15
|
+
Requires-Dist: plotly>=6.3.0
|
|
16
|
+
Requires-Dist: protobuf>=6.32.1
|
|
17
|
+
Requires-Dist: psutil>=7.0.0
|
|
18
|
+
Requires-Dist: python-dotenv>=1.1.1
|
|
19
|
+
Requires-Dist: roboflow>=1.2.7
|
|
20
|
+
Requires-Dist: scikit-learn>=1.7.2
|
|
21
|
+
Requires-Dist: seaborn>=0.13.2
|
|
22
|
+
Requires-Dist: sentencepiece>=0.2.1
|
|
23
|
+
Requires-Dist: supervision>=0.26.1
|
|
24
|
+
Requires-Dist: tqdm>=4.67.1
|
|
25
|
+
Requires-Dist: transformers>=4.56.1
|
|
26
|
+
Requires-Dist: ultralytics>=8.3.193
|
|
27
|
+
Requires-Dist: umap-learn>=0.5.9.post2
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
wunderscout/__init__.py,sha256=UwZMG43lwIAxHCHPg8SYS9A2Azkem9vcZw6tkwFG9kU,217
|
|
2
|
+
wunderscout/core.py,sha256=suVWCsiVmwsevte8tsX_53GEdZ_cQhKcFTknY20BWEw,6345
|
|
3
|
+
wunderscout/exporters.py,sha256=CRczFAYUS6EOedL84wq9WaFWTPKJyGxL7kbZQvLzLOA,1875
|
|
4
|
+
wunderscout/geometry.py,sha256=I5lt00O9jOiEoVpPGy5iVglzA7cgaAdUvzfuFBcJbRA,2197
|
|
5
|
+
wunderscout/heatmap.py,sha256=5R8Zw5Bnk-8eHTcWudo-1a3Mt83WrSbqYX9MkUoht_8,4268
|
|
6
|
+
wunderscout/main.py,sha256=mkiZcWgAuAAc8bUmKQXLkw--PX8dwrRR6qKHZqmyi5M,20008
|
|
7
|
+
wunderscout/pass_network.py,sha256=QC859Pi5VKSgHu1qrE3Zybvu97lNorsb03UAf1IrSbs,3099
|
|
8
|
+
wunderscout/teams.py,sha256=y0IclDACo3F8buVdpqqMCSmZJeWx2uqMkGNbZ6YToVc,2628
|
|
9
|
+
wunderscout/vision.py,sha256=fVX3wtwCwe6AiiGxZjH8u4q2gk3t4gCb4MNvmQf7Lhs,5257
|
|
10
|
+
wunderscout-0.1.2.dist-info/METADATA,sha256=LUmQrjy8MaxHjTmCFILvR95dSQ8kWTYfpxpJPWudjx4,841
|
|
11
|
+
wunderscout-0.1.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
12
|
+
wunderscout-0.1.2.dist-info/RECORD,,
|