wunderscout 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wunderscout/__init__.py +15 -0
- wunderscout/core.py +167 -0
- wunderscout/data.py +33 -0
- wunderscout/exporters.py +42 -0
- wunderscout/geometry.py +74 -0
- wunderscout/heatmap.py +115 -0
- wunderscout/heatmaps.py +271 -0
- wunderscout/pass_network.py +103 -0
- wunderscout/teams.py +76 -0
- wunderscout/vision.py +155 -0
- wunderscout-0.1.11.dist-info/METADATA +87 -0
- wunderscout-0.1.11.dist-info/RECORD +13 -0
- wunderscout-0.1.11.dist-info/WHEEL +4 -0
wunderscout/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .vision import VisionEngine
|
|
2
|
+
from .geometry import PitchMapper
|
|
3
|
+
from .teams import TeamClassifier
|
|
4
|
+
from .core import ScoutingPipeline
|
|
5
|
+
from .exporters import DataExporter
|
|
6
|
+
from .heatmaps import HeatmapGenerator
|
|
7
|
+
|
|
8
|
+
from .data import TrackingResult

# Public API of the package.
__all__ = [
    "VisionEngine",
    "PitchMapper",
    "TeamClassifier",
    "ScoutingPipeline",
    "DataExporter",
    "HeatmapGenerator",
    # Return type of ScoutingPipeline.run and the input of DataExporter /
    # HeatmapGenerator, exported so callers can annotate against it.
    "TrackingResult",
]
|
wunderscout/core.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
import cv2
|
|
2
|
+
import supervision as sv
|
|
3
|
+
import numpy as np
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from .vision import VisionEngine
|
|
6
|
+
from .geometry import PitchMapper
|
|
7
|
+
from .teams import TeamClassifier
|
|
8
|
+
from .exporters import DataExporter
|
|
9
|
+
from .data import TrackingResult
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ScoutingPipeline:
    """Run detection, tracking, team classification and pitch mapping on a video.

    Wires together a ``VisionEngine`` (detection, embeddings, drawing), a
    ``PitchMapper`` (image -> normalized pitch homography) and a
    ``TeamClassifier`` (team assignment from appearance embeddings).
    """

    # Detector class ids used to split the raw detections.
    BALL_ID = 0
    GOALKEEPER_ID = 1
    PLAYER_ID = 2
    REFEREE_ID = 3

    def __init__(self, player_weights, field_weights):
        """Build the pipeline components.

        Args:
            player_weights: Passed through to ``VisionEngine`` (first argument).
            field_weights: Passed through to ``VisionEngine`` (second argument).
        """
        self.engine = VisionEngine(player_weights, field_weights)
        self.mapper = PitchMapper()
        self.classifier = TeamClassifier()

    def run(self, video_path, output_video_path=None):
        """Process ``video_path`` frame by frame and collect tracking data.

        Args:
            video_path: Input video. OpenCV reads its metadata (fps, size) and
                ``sv.get_video_frames_generator`` reads its frames.
            output_video_path: Optional destination for an annotated video
                (``mp4v`` codec). Parent directories are created. If the writer
                cannot be opened, an error is printed and no video is written.

        Returns:
            TrackingResult whose ``frames`` map each frame index to
            ``{"players": {tracker_id: (x, y)}, "ball": (x, y) or None}``.
            Coordinates are homography-projected and clamped to [0, 1]. Frames
            without any homography keep empty player data and no ball.
        """
        # 1. Warm-up (Calibration)
        print("WORKER: Calibrating teams...")
        crops = self.engine.get_calibration_crops(video_path)
        if len(crops) > 0:
            embeddings = self.engine.get_embeddings(crops)
            self.classifier.fit(embeddings)
        else:
            # NOTE(review): the classifier stays unfitted here, yet frames with
            # players below still call get_consensus_team — confirm how an
            # uncalibrated run is supposed to behave.
            print("WARNING: No player crops found for calibration.")

        # 2. Setup Video I/O
        # The capture only provides stream metadata (frames are read through
        # supervision below), so release it as soon as that has been read.
        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        finally:
            cap.release()

        out = None
        if output_video_path:
            output_path_obj = Path(output_video_path)
            output_path_obj.parent.mkdir(parents=True, exist_ok=True)
            out = cv2.VideoWriter(
                output_video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
            )
            if not out.isOpened():
                print(f"ERROR: Could not create video file at {output_video_path}")
                out.release()
                out = None

        tracker = sv.ByteTrack()
        tracking_results = {}

        # 3. Main Processing Loop
        print(f"WORKER: Starting processing: {video_path}")
        frame_generator = sv.get_video_frames_generator(video_path)

        # Stays -1 for a video without frames, so total_frames becomes 0.
        frame_idx = -1
        try:
            for frame_idx, frame in enumerate(frame_generator):
                print(f"WORKER: Processing frame {frame_idx}")

                # --- A. DETECTION / B. FIELD HOMOGRAPHY ---
                all_dets = self.engine.detect_players(frame)
                H = self._estimate_homography(frame)

                # --- C. SEPARATE BALL & OTHERS ---
                ball_detections = all_dets[all_dets.class_id == self.BALL_ID]
                ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)

                other_detections = all_dets[all_dets.class_id != self.BALL_ID]
                other_detections = other_detections.with_nms(threshold=0.5)

                # --- D. TRACKING --- (the ball is not passed to the tracker)
                tracked_objects = tracker.update_with_detections(other_detections)
                tracked_players = tracked_objects[tracked_objects.class_id == self.PLAYER_ID]
                tracked_gks = tracked_objects[tracked_objects.class_id == self.GOALKEEPER_ID]
                tracked_refs = tracked_objects[tracked_objects.class_id == self.REFEREE_ID]

                # --- E. TEAM CLASSIFICATION --- (rewrites class_id in place)
                self._assign_teams(frame, tracked_players, tracked_gks, tracked_refs)

                # --- F. DATA STORAGE ---
                data_targets = sv.Detections.merge([tracked_players, tracked_gks])
                tracking_results[frame_idx] = self._frame_positions(
                    H, data_targets, ball_detections
                )

                # --- G. DRAW & WRITE VIDEO ---
                if out is not None:
                    all_tracked = sv.Detections.merge(
                        [tracked_players, tracked_gks, tracked_refs]
                    )
                    annotated_frame = self.engine.draw_annotations(
                        frame, all_tracked, ball_detections
                    )
                    out.write(annotated_frame)
        finally:
            # 4. Cleanup — always finalize the writer, even if a frame fails.
            if out is not None:
                out.release()

        if out is not None:
            print(f"WORKER: Video saved to {output_video_path}")

        # 5. Return data
        return TrackingResult(
            frames=tracking_results,
            team_assignments=self.classifier.get_final_assignments(),
            total_frames=frame_idx + 1,
            fps=fps,
        )

    def _estimate_homography(self, frame):
        """Return the image->pitch homography for ``frame``, or the last known one.

        Falls back to ``self.mapper.last_h`` (possibly None) when the field model
        yields no keypoints or no keypoint confidences.
        """
        f_res = self.engine.detect_field(frame)
        keypoints = f_res.keypoints
        # Confidences may be None depending on the keypoint output format;
        # get_matrix needs them to filter unreliable points.
        if (
            keypoints is not None
            and len(keypoints.xy) > 0
            and keypoints.conf is not None
        ):
            return self.mapper.get_matrix(
                keypoints.xy[0].cpu().numpy(),
                keypoints.conf[0].cpu().numpy(),
            )
        return self.mapper.last_h

    def _assign_teams(self, frame, tracked_players, tracked_gks, tracked_refs):
        """Overwrite ``class_id`` of the tracked people with team ids, in place.

        Players get the classifier's per-tracker consensus team. Goalkeepers are
        resolved from player positions, but only when players are present.
        Referees are shifted from class 3 to 2.
        """
        # 1. Players
        if len(tracked_players) > 0:
            p_crops = [sv.crop_image(frame, xyxy) for xyxy in tracked_players.xyxy]
            p_pil = [sv.cv2_to_pillow(c) for c in p_crops]
            p_embeddings = self.engine.get_embeddings(p_pil)
            tracked_players.class_id = np.array(
                [
                    self.classifier.get_consensus_team(tid, p_embeddings[i])
                    for i, tid in enumerate(tracked_players.tracker_id)
                ]
            )

        # 2. Goalkeepers
        if len(tracked_gks) > 0 and len(tracked_players) > 0:
            tracked_gks.class_id = self.classifier.resolve_goalkeepers_team_id(
                tracked_players, tracked_gks
            )

        # 3. Referees (Shift ID 3 -> 2)
        if len(tracked_refs) > 0:
            tracked_refs.class_id -= 1

    def _frame_positions(self, H, data_targets, ball_detections):
        """Project people (bottom-center anchor) and the ball onto the pitch.

        Returns ``{"players": {tracker_id: (x, y)}, "ball": (x, y) or None}``
        with coordinates clamped to [0, 1]. The result is empty when ``H`` is None.
        """
        frame_data = {"players": {}, "ball": None}
        if H is None:
            return frame_data

        if len(data_targets) > 0:
            feet_coords = data_targets.get_anchors_coordinates(
                sv.Position.BOTTOM_CENTER
            )
            transformed_feet = self.mapper.transform(feet_coords, H)
            for i, tid in enumerate(data_targets.tracker_id):
                px, py = transformed_feet[i]
                frame_data["players"][tid] = (self._clamp01(px), self._clamp01(py))

        if len(ball_detections) > 0:
            ball_coords = ball_detections.get_anchors_coordinates(sv.Position.CENTER)
            # Only the first ball detection is stored.
            bx, by = self.mapper.transform([ball_coords[0]], H)[0]
            frame_data["ball"] = (self._clamp01(bx), self._clamp01(by))

        return frame_data

    @staticmethod
    def _clamp01(value):
        """Clamp ``value`` into the closed interval [0.0, 1.0]."""
        return max(0.0, min(1.0, value))
|
wunderscout/data.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@dataclass
class TrackingResult:
    """Per-frame tracking output with helpers for per-player/ball queries."""

    # frame index -> per-frame record; as filled by ScoutingPipeline.run it is
    # {"players": {tracker_id: (x, y)}, "ball": (x, y) or None}
    frames: dict[int, dict]
    # tracker id -> team index
    team_assignments: dict[int, int]
    # number of frames processed
    total_frames: int
    # frame rate of the source video
    fps: float

    def _frames_in_order(self):
        """Yield the per-frame records sorted by frame index."""
        for _, entry in sorted(self.frames.items(), key=lambda item: item[0]):
            yield entry

    def get_team_players(self, team: int) -> list[int]:
        """Get player IDs for a specific team (0 or 1)."""
        return [
            player_id
            for player_id, assigned_team in self.team_assignments.items()
            if assigned_team == team
        ]

    def get_all_player_ids(self) -> list[int]:
        """Get all player IDs (insertion order of ``team_assignments``)."""
        return [*self.team_assignments]

    def get_player_trajectory(self, player_id: int) -> list[tuple[float, float]]:
        """Get one player's positions in frame order, skipping frames without them."""
        return [
            entry["players"][player_id]
            for entry in self._frames_in_order()
            if player_id in entry["players"]
        ]

    def get_ball_trajectory(self) -> list[tuple[float, float]]:
        """Get ball positions in frame order, skipping frames without a ball."""
        return [
            entry["ball"]
            for entry in self._frames_in_order()
            if entry["ball"] is not None
        ]
|
wunderscout/exporters.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import csv
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from .data import TrackingResult
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DataExporter:
    """Export a ``TrackingResult`` as per-team CSV files."""

    @staticmethod
    def save_csvs(result: TrackingResult, output_path: str):
        """Export tracking data to CSV files (one per team).

        ``output_path`` only supplies the directory and base name. Its suffix is
        dropped and ``_Team1.csv`` (team 0) / ``_Team2.csv`` (team 1) is appended.

        Each file contains:
        - two label rows (team name and player id, each repeated for the X/Y
          column pair);
        - a header row;
        - one row per frame index in ``range(result.total_frames)``.

        Missing positions are written as "NaN".

        Args:
            result: Tracking data to export.
            output_path: Target path; parent directories are created.
        """
        path_obj = Path(output_path)
        path_obj.parent.mkdir(parents=True, exist_ok=True)
        base_name = str(path_obj.with_suffix(""))

        team1_ids = result.get_team_players(0)
        team2_ids = result.get_team_players(1)
        fps = result.fps
        missing = ("NaN", "NaN")

        def format_time(f_idx):
            # OpenCV returns 0 for a frame rate it cannot read; write "NaN"
            # instead of raising ZeroDivisionError.
            if not fps:
                return "NaN"
            return f"{f_idx / fps:.2f}"

        def write_file(filename, team_name, ids):
            with open(filename, "w", newline="") as f:
                writer = csv.writer(f)
                # Label rows: team name / player id above each X,Y column pair.
                writer.writerow(
                    ["", "", ""] + [team_name for _ in ids for _ in (0, 1)] + ["", ""]
                )
                writer.writerow(
                    ["", "", ""] + [str(pid) for pid in ids for _ in (0, 1)] + ["", ""]
                )
                writer.writerow(
                    ["Period", "Frame", "Time [s]"]
                    + [f"Player{pid}_{axis}" for pid in ids for axis in ("X", "Y")]
                    + ["Ball_X", "Ball_Y"]
                )

                for f_idx in range(result.total_frames):
                    data = result.frames.get(f_idx, {"ball": None, "players": {}})
                    # Everything is written as period 1.
                    row = [1, f_idx, format_time(f_idx)]
                    for tid in ids:
                        row.extend(data["players"].get(tid, missing))
                    ball = data["ball"]
                    row.extend(ball if ball is not None else missing)
                    writer.writerow(row)

        write_file(f"{base_name}_Team1.csv", "Team1", sorted(team1_ids))
        write_file(f"{base_name}_Team2.csv", "Team2", sorted(team2_ids))
|
wunderscout/geometry.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import cv2
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
# Normalized pitch coordinates for each keypoint index of the field model.
# x runs along the pitch length (0 = left goal line, 1 = right goal line) and
# y across the width (0 = top edge, 1 = bottom edge). The fractions are
# consistent with a 105 m x 68 m pitch (e.g. 0.157 * 105 ~= 16.5 m box depth).
# NOTE(review): index order must match the keypoint order the field model was
# trained with — confirm against its dataset config.
PITCH_CONFIG = {
    # --- LEFT GOAL LINE ---
    0: (0.000, 0.000),  # Top-Left Corner
    1: (0.000, 0.204),  # Top Edge of Penalty Box
    2: (0.000, 0.365),  # Top Edge of Goal Area
    3: (0.000, 0.635),  # Bottom Edge of Goal Area
    4: (0.000, 0.796),  # Bottom Edge of Penalty Box
    5: (0.000, 1.000),  # Bottom-Left Corner
    # --- LEFT PENALTY AREA ---
    6: (0.052, 0.365),
    7: (0.052, 0.635),
    8: (0.105, 0.500),  # Penalty Spot (Left)
    9: (0.157, 0.204),
    10: (0.157, 0.392),
    11: (0.157, 0.608),
    12: (0.157, 0.796),
    # --- MIDFIELD ---
    13: (0.413, 0.500),
    14: (0.500, 0.000),
    15: (0.500, 0.365),
    16: (0.500, 0.635),
    17: (0.500, 1.000),
    18: (0.587, 0.500),
    # --- RIGHT PENALTY AREA ---
    19: (0.843, 0.204),
    20: (0.843, 0.392),
    21: (0.843, 0.608),
    22: (0.843, 0.796),
    23: (0.895, 0.500),  # Penalty Spot (Right)
    24: (0.948, 0.365),
    25: (0.948, 0.635),
    # --- RIGHT GOAL LINE ---
    26: (1.000, 0.000),
    27: (1.000, 0.204),
    28: (1.000, 0.365),
    29: (1.000, 0.635),
    30: (1.000, 0.796),
    31: (1.000, 1.000),
}


class PitchMapper:
    """Estimate and apply an image -> normalized-pitch homography.

    The last successfully estimated matrix is cached in ``last_h``. It is reused
    when a frame does not provide enough reliable keypoints or when estimation
    fails.
    """

    def __init__(self, pitch_config=PITCH_CONFIG, min_confidence=0.5):
        """
        Args:
            pitch_config: Mapping keypoint index -> (x, y) pitch coordinates.
                It is only read, never mutated.
            min_confidence: Keypoints whose confidence is not strictly greater
                than this value are ignored.
        """
        self.pitch_config = pitch_config
        self.min_confidence = min_confidence
        self.last_h = None

    def get_matrix(self, keypoints_xy, keypoints_conf):
        """Estimate the homography from detected keypoints.

        Args:
            keypoints_xy: Sequence of (x, y) image coordinates, one per keypoint index.
            keypoints_conf: Sequence of confidences aligned with ``keypoints_xy``.

        Returns:
            The new 3x3 homography, or the previously cached one (possibly None)
            when fewer than 4 reliable keypoints exist or estimation fails.
        """
        src_points = []
        dst_points = []

        for i, (xy, conf) in enumerate(zip(keypoints_xy, keypoints_conf)):
            if conf > self.min_confidence and i in self.pitch_config:
                src_points.append(xy)
                dst_points.append(self.pitch_config[i])

        # A homography needs at least 4 point correspondences.
        if len(src_points) >= 4:
            H, _ = cv2.findHomography(
                np.asarray(src_points, dtype=np.float32),
                np.asarray(dst_points, dtype=np.float32),
                cv2.RANSAC,
            )
            # findHomography yields None when no model can be estimated (e.g.
            # degenerate point sets); keep the last good matrix in that case.
            if H is not None:
                self.last_h = H

        return self.last_h

    def transform(self, points, H=None):
        """Project image points to pitch coordinates.

        Args:
            points: Sequence of (x, y) image points.
            H: Homography to use; defaults to the cached ``last_h``.

        Returns:
            An ``(N, 2)`` array of projected points, or ``[]`` when there is no
            homography or no points.
        """
        target_h = H if H is not None else self.last_h
        if target_h is None or len(points) == 0:
            return []

        points_reshaped = np.array(points).reshape(-1, 1, 2).astype(np.float32)
        projected = cv2.perspectiveTransform(points_reshaped, target_h)
        return projected.reshape(-1, 2)
|
wunderscout/heatmap.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import gaussian_kde
|
|
7
|
+
|
|
8
|
+
# === Load the raw tracking data CSV ===
# NOTE: "header=2" → skip the first two rows (team/labels) and use row 3 as header
df = pd.read_csv(
    "./data/Sample_Game_1_RawTrackingData_Away_Team.csv",
    header=2,
)

# === Clean column names so each player has _X and _Y ===
# Each "Player..."/"Ball" header is assumed to be followed by one unnamed column
# holding the Y value; the pair is renamed to "<label>_X" / "<label>_Y".
cleaned_columns = []
colnames = df.columns.tolist()
i = 0
while i < len(colnames):
    col = colnames[i]
    if col.startswith("Player") or col.startswith("Ball"):
        cleaned_columns.append(f"{col}_X")
        cleaned_columns.append(f"{col}_Y")
        i += 2
    else:
        cleaned_columns.append(col)
        i += 1
df.columns = cleaned_columns

print("Columns cleaned. First few rows:")
print(df.head())

# === Extract Player17 (drop NaN values where tracking failed) ===
player17 = df[["Player17_X", "Player17_Y"]].dropna()
x = player17["Player17_X"].to_numpy()
y = player17["Player17_Y"].to_numpy()

# x.max() below (and the KDE) cannot work on an empty series; fail clearly.
if len(x) == 0:
    raise ValueError("No valid Player17 samples found in the tracking CSV")

# === Detect scale (normalized [0,1] or real meters) ===
if x.max() <= 1.5 and y.max() <= 1.5:
    print("Scaling Player17 data from normalized [0,1] to meters...")
    x = x * 105  # pitch length in meters
    y = y * 68  # pitch width in meters
else:
    print("Data appears to already be in meters, leaving as is.")

print("First 10 points:", list(zip(x[:10], y[:10])))

# All outputs below go to ./heatmap; make sure it exists before writing.
Path("./heatmap").mkdir(parents=True, exist_ok=True)

# =============================================================================
# 1. Scatter Plot (sanity check, raw positions)
# =============================================================================
fig, ax = plt.subplots(figsize=(10, 7))
# Pitch outline
ax.plot([0, 105, 105, 0, 0], [0, 0, 68, 68, 0], color="black")
ax.plot([52.5, 52.5], [0, 68], color="black")  # halfway line
# Player positions
ax.scatter(x, y, s=1, alpha=0.3, color="blue")
ax.set_xlim(0, 105)
ax.set_ylim(0, 68)
ax.set_title("Player17 Movement Scatter (raw positions)")
plt.savefig("./heatmap/player17_scatter.png", dpi=150, bbox_inches="tight")
plt.close(fig)

# =============================================================================
# 2. Histogram Heatmap (occupancy grid)
# =============================================================================
heatmap, xedges, yedges = np.histogram2d(x, y, bins=(50, 34), range=[[0, 105], [0, 68]])

fig, ax = plt.subplots(figsize=(10, 7))
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
im = ax.imshow(
    heatmap.T, origin="lower", extent=extent, cmap="Blues", alpha=0.7, aspect="auto"
)
ax.plot([0, 105, 105, 0, 0], [0, 0, 68, 68, 0], color="black")
ax.plot([52.5, 52.5], [0, 68], color="black")
fig.colorbar(im, ax=ax, label="Frames")
ax.set_title("Player17 Heatmap (Histogram)")
plt.savefig("./heatmap/player17_histogram.png", dpi=150, bbox_inches="tight")
plt.close(fig)

# === Export histogram data as JSON for three.js ===
heatmap_data = {
    "xedges": xedges.tolist(),
    "yedges": yedges.tolist(),
    "values": heatmap.T.tolist(),  # transpose so rows correspond to y-axis correctly
}
with open("./heatmap/player17_histogram.json", "w") as f:
    json.dump(heatmap_data, f)

# =============================================================================
# 3. KDE Heatmap (smoothed density field)
# =============================================================================
values = np.vstack([x, y])
kde = gaussian_kde(values)

# Define mesh grid
X, Y = np.meshgrid(np.linspace(0, 105, 100), np.linspace(0, 68, 68))
Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)

fig, ax = plt.subplots(figsize=(10, 7))
sns.kdeplot(x=x, y=y, fill=True, cmap="Blues", alpha=0.7, thresh=0.05, levels=50, ax=ax)
ax.plot([0, 105, 105, 0, 0], [0, 0, 68, 68, 0], color="black")
ax.plot([52.5, 52.5], [0, 68], color="black")
ax.set_xlim(0, 105)
ax.set_ylim(0, 68)
ax.set_title("Player17 Heatmap (KDE Smoothed)")
plt.savefig("./heatmap/player17_kde.png", dpi=150, bbox_inches="tight")
plt.close(fig)

# === Export KDE density field for three.js ===
kde_data = {
    "x": X[0].tolist(),  # x grid coordinates
    "y": Y[:, 0].tolist(),  # y grid coordinates
    "values": Z.tolist(),  # density values, rows follow y
}
with open("./heatmap/player17_kde.json", "w") as f:
    json.dump(kde_data, f)

print("Outputs saved: scatter, histogram PNG+JSON, KDE PNG+JSON for Player17")
|
wunderscout/heatmaps.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import json
|
|
3
|
+
from scipy.stats import gaussian_kde
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional, Literal, Any
|
|
6
|
+
from .data import TrackingResult
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class HeatmapGenerator:
    """Build histogram and KDE heatmaps from normalized tracking positions.

    Positions from a TrackingResult are scaled from [0, 1] to meters with the
    configured pitch size before binning or density estimation. Outputs are
    plain dicts/lists (numpy scalar ids converted) so they can be written with
    ``save_heatmap``.
    """

    def __init__(
        self,
        pitch_length: float = 105.0,
        pitch_width: float = 68.0,
        histogram_bins: tuple[int, int] = (50, 34),
        kde_grid_size: tuple[int, int] = (100, 68),
        min_samples_for_kde: int = 10,  # Minimum samples needed for KDE
        kde_jitter_seed: Optional[int] = 0,
    ):
        """
        Initialize heatmap generator with pitch dimensions and resolution.

        Args:
            pitch_length: Length of pitch in meters (default 105m)
            pitch_width: Width of pitch in meters (default 68m)
            histogram_bins: (x_bins, y_bins) for histogram heatmap
            kde_grid_size: (x_points, y_points) for KDE grid resolution
            min_samples_for_kde: Minimum number of samples required for KDE
            kde_jitter_seed: Seed for the small jitter added before the KDE.
                The default makes KDE output reproducible; ``None`` draws fresh
                entropy. The global ``np.random`` state is never touched.
        """
        self.pitch_length = pitch_length
        self.pitch_width = pitch_width
        self.histogram_bins = histogram_bins
        self.kde_grid_size = kde_grid_size
        self.min_samples_for_kde = min_samples_for_kde
        self.kde_jitter_seed = kde_jitter_seed

    @staticmethod
    def _to_builtin(value: Any) -> Any:
        """Convert numpy scalars (e.g. ``np.int64`` tracker ids) to Python scalars.

        ``json.dump`` rejects numpy integers as values and as dict keys. Other
        values are returned unchanged.
        """
        return value.item() if isinstance(value, np.generic) else value

    def _scale_to_meters(self, positions: np.ndarray) -> np.ndarray:
        """Convert normalized [0, 1] coordinates to meters.

        Returns a new float array; ``positions`` is not modified. Integer input
        is accepted (in-place float scaling of an int array would fail).
        """
        scaled = np.array(positions, dtype=float)
        scaled[:, 0] *= self.pitch_length
        scaled[:, 1] *= self.pitch_width
        return scaled

    def _has_sufficient_variation(self, x: np.ndarray, y: np.ndarray) -> bool:
        """Check if data has sufficient spatial variation for KDE.

        Requires at least 2 samples and a range above 1 cm in both dimensions.
        """
        if len(x) < 2:
            return False

        # peak-to-peak (max - min); identical points give 0
        x_range = np.ptp(x)
        y_range = np.ptp(y)

        # Need at least some variation in both dimensions (1cm threshold)
        return x_range > 0.01 and y_range > 0.01

    def generate_player_heatmap(
        self,
        result: "TrackingResult",
        player_id: int,
        method: Literal["histogram", "kde", "both"] = "both",
    ) -> dict[str, Any]:
        """
        Generate heatmap for a single player.

        Args:
            result: TrackingResult from pipeline
            player_id: Player tracker ID
            method: "histogram", "kde", or "both"

        Returns:
            Dictionary with heatmap data in format ready for JSON export. The
            "histogram"/"kde" keys are omitted when that part is skipped or fails.

        Raises:
            ValueError: If the player has no trajectory data.
        """
        trajectory = result.get_player_trajectory(player_id)

        if len(trajectory) == 0:
            raise ValueError(f"No trajectory data found for player {player_id}")

        positions_meters = self._scale_to_meters(np.array(trajectory))
        x, y = positions_meters[:, 0], positions_meters[:, 1]

        output: dict[str, Any] = {
            "player_id": self._to_builtin(player_id),
            "sample_count": len(trajectory),
        }

        # Histogram works with any amount of data
        if method in ["histogram", "both"]:
            try:
                output["histogram"] = self._compute_histogram(x, y)
            except Exception as e:
                print(f"Warning: Histogram failed for player {player_id}: {e}")

        # Only attempt KDE if we have enough quality data
        if method in ["kde", "both"]:
            if len(trajectory) < self.min_samples_for_kde:
                print(
                    f"Info: Player {player_id} has only {len(trajectory)} samples "
                    f"(minimum {self.min_samples_for_kde} required for KDE). "
                    "Skipping KDE, histogram only."
                )
            elif not self._has_sufficient_variation(x, y):
                print(
                    f"Info: Player {player_id} has insufficient spatial variation "
                    "for KDE. Skipping KDE, histogram only."
                )
            else:
                try:
                    output["kde"] = self._compute_kde(x, y)
                except Exception as e:
                    print(f"Warning: KDE failed for player {player_id}: {e}")

        return output

    def _compute_histogram(self, x: np.ndarray, y: np.ndarray) -> dict[str, Any]:
        """Compute 2D histogram heatmap.

        ``values`` is transposed so that rows follow y and columns follow x.
        """
        heatmap, xedges, yedges = np.histogram2d(
            x,
            y,
            bins=self.histogram_bins,
            range=[[0, self.pitch_length], [0, self.pitch_width]],
        )

        return {
            "xedges": xedges.tolist(),
            "yedges": yedges.tolist(),
            "values": heatmap.T.tolist(),
        }

    def _compute_kde(self, x: np.ndarray, y: np.ndarray) -> dict[str, Any]:
        """
        Compute KDE smoothed density field.

        Returns dict with:
            - x: 1D list of x coordinates
            - y: 1D list of y coordinates
            - values: 2D list where values[i][j] = density at [x[j], y[i]]
        """
        # Add small (1cm) jitter to prevent perfect collinearity when points are
        # nearly identical. A local, seeded generator keeps the output
        # reproducible and leaves the global numpy RNG state untouched.
        jitter_amount = 0.01
        rng = np.random.default_rng(self.kde_jitter_seed)
        x_jittered = x + rng.normal(0.0, jitter_amount, size=x.shape)
        y_jittered = y + rng.normal(0.0, jitter_amount, size=y.shape)

        kde = gaussian_kde(np.vstack([x_jittered, y_jittered]))

        # Create coordinate grids
        x_coords = np.linspace(0, self.pitch_length, self.kde_grid_size[0])
        y_coords = np.linspace(0, self.pitch_width, self.kde_grid_size[1])
        X, Y = np.meshgrid(x_coords, y_coords)

        # Evaluate KDE on grid
        Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)

        return {
            "x": x_coords.tolist(),
            "y": y_coords.tolist(),
            "values": Z.tolist(),
        }

    def generate_team_heatmap(
        self,
        result: "TrackingResult",
        team: int,
        method: Literal["histogram", "kde", "both"] = "both",
    ) -> dict[str, Any]:
        """
        Generate aggregated heatmap for entire team.

        Args:
            result: TrackingResult from pipeline
            team: Team ID (0 or 1)
            method: "histogram", "kde", or "both"

        Returns:
            Dictionary with team id, player/sample counts and the requested
            heatmaps ("histogram"/"kde" omitted when skipped or failed).

        Raises:
            ValueError: If the team has no players or no position data.
        """
        player_ids = result.get_team_players(team)

        if len(player_ids) == 0:
            raise ValueError(f"No players found for team {team}")

        # Collect all positions from all players
        all_positions = []
        for pid in player_ids:
            all_positions.extend(result.get_player_trajectory(pid))

        if len(all_positions) == 0:
            raise ValueError(f"No position data found for team {team}")

        positions_meters = self._scale_to_meters(np.array(all_positions))
        x, y = positions_meters[:, 0], positions_meters[:, 1]

        output: dict[str, Any] = {
            "team_id": self._to_builtin(team),
            "player_count": len(player_ids),
            "sample_count": len(all_positions),
        }

        # Histogram (always attempt)
        if method in ["histogram", "both"]:
            try:
                output["histogram"] = self._compute_histogram(x, y)
            except Exception as e:
                print(f"Warning: Team histogram failed for team {team}: {e}")

        # KDE (with quality checks)
        if method in ["kde", "both"]:
            if len(all_positions) < self.min_samples_for_kde:
                print(
                    f"Info: Team {team} has only {len(all_positions)} samples. "
                    "Skipping KDE."
                )
            elif not self._has_sufficient_variation(x, y):
                print(f"Info: Team {team} has insufficient variation. Skipping KDE.")
            else:
                try:
                    output["kde"] = self._compute_kde(x, y)
                except Exception as e:
                    print(f"Warning: KDE failed for team {team}: {e}")

        return output

    def generate_all_players_heatmaps(
        self,
        result: "TrackingResult",
        method: Literal["histogram", "kde", "both"] = "both",
    ) -> dict[int, dict[str, Any]]:
        """
        Generate heatmaps for all players.

        Players without trajectory data are skipped with a warning.

        Returns:
            Dictionary mapping player_id (as a Python int-like key) -> heatmap data
        """
        all_heatmaps = {}

        for player_id in result.get_all_player_ids():
            try:
                all_heatmaps[self._to_builtin(player_id)] = self.generate_player_heatmap(
                    result, player_id, method
                )
            except ValueError as e:
                print(f"Warning: Skipping player {player_id}: {e}")

        return all_heatmaps

    def save_heatmap(
        self,
        heatmap_data: dict[str, Any],
        output_path: str,
        pretty: bool = False,
    ):
        """Save heatmap data to JSON file, creating parent directories."""
        path_obj = Path(output_path)
        path_obj.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w") as f:
            json.dump(heatmap_data, f, indent=2 if pretty else None)
|
|
271
|
+
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import networkx as nx
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
import os
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
# Load match events JSON (replace with your actual file path)
# JSON is UTF-8; relying on the platform's locale-default encoding can garble
# or fail on non-ASCII player names.
with open(
    "./data/3825818.json",
    "r",
    encoding="utf-8",
) as f:
    events = json.load(f)

# Build mapping from player ID -> player name from the Starting XI events
player_id_to_name = {}

for ev in events:
    if ev["type"]["name"] == "Starting XI":
        for lineup in ev["tactics"]["lineup"]:
            pid = lineup["player"]["id"]
            name = lineup["player"]["name"]
            player_id_to_name[pid] = name

# Data structures for passes and positions
edges = defaultdict(int)  # (passer, recipient) -> count of passes
player_positions = defaultdict(list)  # player_id -> list of [x, y] positions

TEAM_NAME = "Real Sociedad"

# Extract completed passes
for ev in events:
    if ev["type"]["name"] == "Pass" and ev["team"]["name"] == TEAM_NAME:
        passer = ev["player"]["id"]
        recipient = ev.get("pass", {}).get("recipient", {}).get("id")
        # A missing "outcome" is treated as a completed pass.
        outcome = ev.get("pass", {}).get("outcome", {"name": "Complete"})["name"]

        if outcome == "Complete" and recipient is not None:
            edges[(passer, recipient)] += 1
            start = ev["location"]
            end = ev["pass"]["end_location"]
            # The passer is placed at the pass origin, the recipient at its end.
            player_positions[passer].append(start)
            player_positions[recipient].append(end)

# Calculate avg positions
avg_positions = {}
for player_id, coords in player_positions.items():
    xs = [pt[0] for pt in coords]
    ys = [pt[1] for pt in coords]
    avg_positions[player_id] = [sum(xs) / len(xs), sum(ys) / len(ys)]

# Build a JSON-friendly structure for export (nodes + links)
nodes = [{"id": pid, "x": pos[0], "y": pos[1]} for pid, pos in avg_positions.items()]
links = [
    {"source": src, "target": tgt, "value": count}
    for (src, tgt), count in edges.items()
]

network = {"nodes": nodes, "links": links}
os.makedirs("pass_network", exist_ok=True)
with open("./pass_network/pass_network.json", "w", encoding="utf-8") as f:
    json.dump(network, f, indent=2)

# Build NetworkX graph
G = nx.DiGraph()

# Add nodes with positions
for pid, pos in avg_positions.items():
    G.add_node(pid, pos=(pos[0], pos[1]))

# Add edges with weights
for (src, tgt), count in edges.items():
    G.add_edge(src, tgt, weight=count)

# Draw graph
pos = nx.get_node_attributes(G, "pos")
labels = {pid: player_id_to_name.get(pid, str(pid)) for pid in G.nodes()}

fig, ax = plt.subplots(figsize=(10, 7))

# Draw pitch outline (120 x 80 units)
ax.set_xlim(0, 120)
ax.set_ylim(0, 80)
ax.plot([0, 120, 120, 0, 0], [0, 0, 80, 80, 0], color="black")

# Draw nodes
nx.draw_networkx_nodes(G, pos, ax=ax, node_color="skyblue", node_size=500)

# Draw edges, width proportional to pass count
nx.draw_networkx_edges(
    G,
    pos,
    ax=ax,
    width=[d["weight"] * 0.2 for _, _, d in G.edges(data=True)],
    alpha=0.7,
    arrowsize=10,
)

# Draw player names
nx.draw_networkx_labels(G, pos, labels=labels, ax=ax, font_size=8)

plt.title("Team Pass Network")
plt.savefig("./pass_network/pass_network_viz.png", dpi=150, bbox_inches="tight")
plt.close(fig)
|
wunderscout/teams.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import umap
|
|
3
|
+
from sklearn.cluster import KMeans
|
|
4
|
+
import supervision as sv
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TeamClassifier:
    """Unsupervised two-team classifier for player appearance embeddings.

    Embeddings are projected with UMAP (3 components) and clustered into two
    groups with K-Means. Per-tracker predictions are kept in a rolling buffer
    so the reported team is a majority vote over recent frames, which smooths
    out single-frame misclassifications.
    """

    def __init__(self, history_size=50):
        """Create the reducer/clusterer and an empty per-tracker vote history.

        Args:
            history_size: maximum number of most recent votes kept per
                tracker id (default 50, the previously hard-coded window).

        Raises:
            ValueError: if ``history_size`` is smaller than 1.
        """
        if history_size < 1:
            raise ValueError("history_size must be >= 1")
        self.reducer = umap.UMAP(n_components=3)
        self.clusterer = KMeans(n_clusters=2, n_init=10, random_state=42)
        self.history_size = history_size
        # tracker_id -> list of recent 0/1 cluster predictions, oldest first.
        self.history = {}

    def fit(self, embeddings):
        """Fit the UMAP projection and the 2-cluster K-Means on ``embeddings``."""
        projections = self.reducer.fit_transform(embeddings)
        self.clusterer.fit(projections)

    @staticmethod
    def _majority(votes):
        """Return 1 if strictly more than half of the 0/1 ``votes`` are 1, else 0."""
        return 1 if sum(votes) / len(votes) > 0.5 else 0

    def get_consensus_team(self, tracker_id, embedding):
        """Classify one embedding and return the tracker's majority team.

        Args:
            tracker_id: key identifying the tracked player.
            embedding: a single embedding vector; it is reshaped to one row
                before projection.

        Returns:
            0 or 1: the majority vote over this tracker's most recent
            ``history_size`` predictions (ties resolve to 0).
        """
        proj = self.reducer.transform(np.asarray(embedding).reshape(1, -1))
        pred = int(self.clusterer.predict(proj)[0])

        votes = self.history.setdefault(tracker_id, [])
        votes.append(pred)
        # Rolling window: drop the oldest votes beyond history_size.
        if len(votes) > self.history_size:
            del votes[: len(votes) - self.history_size]

        return self._majority(votes)

    @staticmethod
    def _assign_to_nearest_centroid(players_xy, players_team_id, goalkeepers_xy):
        """Assign each goalkeeper point to the team whose player centroid is nearest.

        Args:
            players_xy: (N, 2) array of player anchor coordinates.
            players_team_id: length-N array of player team ids (0 or 1).
            goalkeepers_xy: (M, 2) array of goalkeeper anchor coordinates.

        Returns:
            Length-M int array of team ids. A team with no players is never
            chosen. If neither team has players, every goalkeeper gets 0.
        """
        players_xy = np.asarray(players_xy, dtype=float)
        goalkeepers_xy = np.asarray(goalkeepers_xy, dtype=float).reshape(-1, 2)
        team_ids = np.asarray(players_team_id)

        distances = []
        for team in (0, 1):
            mask = team_ids == team
            if np.any(mask):
                centroid = players_xy[mask].mean(axis=0)
                distances.append(np.linalg.norm(goalkeepers_xy - centroid, axis=1))
            else:
                # Missing team: infinite distance so it can never be selected.
                # A placeholder coordinate such as (0, 0) is a real pixel
                # location and could wrongly win the comparison.
                distances.append(np.full(len(goalkeepers_xy), np.inf))

        dist_0, dist_1 = distances
        if np.all(np.isinf(dist_0)) and np.all(np.isinf(dist_1)):
            return np.zeros(len(goalkeepers_xy), dtype=int)
        return np.where(dist_0 < dist_1, 0, 1)

    def resolve_goalkeepers_team_id(self, players, goalkeepers):
        """Assign goalkeepers to the team whose centroid is closest.

        Args:
            players: sv.Detections already classified with class_id 0 or 1.
            goalkeepers: sv.Detections for goalkeepers.

        Returns:
            Int array with one team id (0 or 1) per goalkeeper. If there are
            no players or no goalkeepers, every goalkeeper defaults to 0.
        """
        if len(players) == 0 or len(goalkeepers) == 0:
            return np.zeros(len(goalkeepers), dtype=int)

        players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
        goalkeepers_xy = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
        return self._assign_to_nearest_centroid(
            players_xy, players.class_id, goalkeepers_xy
        )

    def get_final_assignments(self):
        """Return ``{tracker_id: team}`` majority votes for every tracker with votes."""
        return {
            tid: self._majority(votes)
            for tid, votes in self.history.items()
            if len(votes) > 0
        }
|
|
76
|
+
|
wunderscout/vision.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
import copy
from pathlib import Path

import torch
from ultralytics import YOLO
import supervision as sv
from transformers import AutoProcessor, SiglipVisionModel, data
from roboflow import Roboflow
from tqdm import tqdm
from more_itertools import chunked
import numpy as np
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class VisionEngine:
    """Detection, embedding and annotation front-end for match footage.

    Bundles two YOLO models (player detection and field detection), a SigLIP
    vision model used to embed player crops, and the supervision annotators
    used to draw results onto frames.
    """

    def __init__(self, player_weights, field_weights, device=None):
        """Load the YOLO and SigLIP models and build the annotators.

        Args:
            player_weights: weights passed to ``YOLO`` for player detection.
            field_weights: weights passed to ``YOLO`` for field detection.
            device: torch device for the SigLIP model. Defaults to ``"cuda"``
                when available, otherwise ``"cpu"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.player_model = YOLO(player_weights)
        self.field_model = YOLO(field_weights)

        # Siglip for embeddings (see get_embeddings)
        siglip_path = "google/siglip-base-patch16-224"
        self.siglip_model = SiglipVisionModel.from_pretrained(siglip_path).to(
            self.device
        )
        self.siglip_processor = AutoProcessor.from_pretrained(siglip_path)

        # --- Annotators ---
        # Palette: 0=Blue, 1=Pink, 2=Yellow (Referee)
        self.palette = sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"])

        self.ellipse_annotator = sv.EllipseAnnotator(
            color=self.palette,
            thickness=2,
        )
        self.label_annotator = sv.LabelAnnotator(
            color=self.palette,
            text_color=sv.Color.from_hex("#000000"),
            text_position=sv.Position.BOTTOM_CENTER,
        )
        # Used to mark the ball in draw_annotations.
        self.triangle_annotator = sv.TriangleAnnotator(
            color=sv.Color.from_hex("#FFD700"), base=25, height=21, outline_thickness=1
        )

    def get_calibration_crops(self, video_path, stride=30, player_class_id=2):
        """Collect PIL crops of detected players from sampled video frames.

        Args:
            video_path: path of the input video.
            stride: only every ``stride``-th frame is sampled.
            player_class_id: class id of players in the player model's output
                (default 2, the previously hard-coded value).

        Returns:
            List of PIL images, one per player detection across sampled frames.
        """
        frame_generator = sv.get_video_frames_generator(
            source_path=video_path, stride=stride
        )

        crops = []
        for frame in frame_generator:
            detections = self.detect_players(frame)
            # Filter for players only for calibration
            players = detections[detections.class_id == player_class_id]
            crops.extend(
                sv.cv2_to_pillow(sv.crop_image(frame, xyxy)) for xyxy in players.xyxy
            )

        print(f"VisionEngine: Collected {len(crops)} calibration crops.")
        return crops

    def get_embeddings(self, pil_crops, batch_size=32):
        """Embed image crops with SigLIP, ``batch_size`` crops at a time.

        Each crop's embedding is the mean of the model's ``last_hidden_state``
        over dim 1.

        Returns:
            Numpy array of the per-batch embeddings concatenated along axis 0,
            or an empty array when ``pil_crops`` is empty.
        """
        batches = chunked(pil_crops, batch_size)
        data_list = []

        with torch.no_grad():
            for batch in batches:
                inputs = self.siglip_processor(images=batch, return_tensors="pt").to(
                    self.device
                )
                outputs = self.siglip_model(**inputs)
                embeddings = torch.mean(outputs.last_hidden_state, dim=1).cpu().numpy()
                data_list.append(embeddings)

        return np.concatenate(data_list) if data_list else np.array([])

    def detect_players(self, frame, conf=0.3):
        """Run the player model on one frame and return ``sv.Detections``."""
        result = self.player_model.predict(frame, conf=conf, verbose=False)[0]
        return sv.Detections.from_ultralytics(result)

    def detect_field(self, frame, conf=0.3):
        """Run the field model on one frame and return the raw Ultralytics result."""
        result = self.field_model.predict(frame, conf=conf, verbose=False)[0]
        return result

    def draw_annotations(self, frame, all_detections, ball_detections):
        """Return a copy of ``frame`` with the ball and people annotated.

        The inputs are not modified. People get ellipses colored by
        ``class_id`` and, when tracker ids are present, ``#<tracker_id>``
        labels.
        """
        annotated_frame = frame.copy()

        # 1. Draw Ball
        annotated_frame = self.triangle_annotator.annotate(
            scene=annotated_frame, detections=ball_detections
        )

        # 2. Draw People (Players, GKs, Refs)
        if len(all_detections) == 0:
            return annotated_frame

        # Work on a shallow copy so the caller's detections keep their
        # original class_id array; class_id must be int for color mapping.
        detections = copy.copy(all_detections)
        detections.class_id = detections.class_id.astype(int)

        annotated_frame = self.ellipse_annotator.annotate(
            scene=annotated_frame, detections=detections
        )

        # Labels show tracker ids, so skip them when the detections are untracked.
        if detections.tracker_id is not None:
            labels = [f"#{tracker_id}" for tracker_id in detections.tracker_id]
            annotated_frame = self.label_annotator.annotate(
                scene=annotated_frame, detections=detections, labels=labels
            )

        return annotated_frame
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class ScoutingTrainer:
    """Download Roboflow datasets and fine-tune YOLO models on them."""

    def __init__(self, api_key):
        """Create a Roboflow client authenticated with ``api_key``."""
        self.rf = Roboflow(api_key=api_key)

    def train_players(
        self,
        workspace,
        project,
        version,
        epochs=300,
        output_dir="../runs/training/player",
        *,
        base_model="../data/base_models/yolo11m.pt",
        imgsz=1280,
        device=0,
        batch=2,
    ):
        """Fine-tune a YOLO detection model on a Roboflow player dataset.

        Args:
            workspace, project, version: Roboflow dataset coordinates.
            epochs: number of training epochs.
            output_dir: passed to Ultralytics as ``project`` (run directory).
            base_model: starting weights (default keeps the previous path).
            imgsz, device, batch: forwarded to ``YOLO.train``; defaults match
                the previously hard-coded values.

        Returns:
            Whatever ``YOLO.train`` returns.
        """
        rf_project = self.rf.workspace(workspace).project(project)
        dataset = rf_project.version(version).download("yolov11")
        model = YOLO(base_model)

        return model.train(
            data=f"{dataset.location}/data.yaml",
            epochs=epochs,
            imgsz=imgsz,
            plots=True,
            device=device,
            batch=batch,
            project=output_dir,
        )

    def train_field(
        self,
        workspace,
        project,
        version,
        epochs=300,
        output_dir="../runs/training/field",
        *,
        base_model="yolo11m-pose.pt",
        dataset_location="../data/data_sets/",
        imgsz=1080,
        device=0,
        batch=2,
    ):
        """Fine-tune a YOLO pose model on a Roboflow field-keypoint dataset.

        Args:
            workspace, project, version: Roboflow dataset coordinates.
            epochs: number of training epochs.
            output_dir: passed to Ultralytics as ``project`` (run directory).
            base_model: starting weights (default keeps the previous model).
            dataset_location: directory the dataset is downloaded into.
            imgsz, device, batch: forwarded to ``YOLO.train``; defaults match
                the previously hard-coded values.

        Returns:
            Whatever ``YOLO.train`` returns.
        """
        rf_project = self.rf.workspace(workspace).project(project)
        # Honour the caller's requested dataset version (was hard-coded to 15).
        rf_version = rf_project.version(version)
        dataset = rf_version.download("yolov8", location=dataset_location)
        model = YOLO(base_model)

        return model.train(
            data=f"{dataset.location}/data.yaml",
            save=True,
            epochs=epochs,
            plots=True,
            imgsz=imgsz,
            device=device,
            batch=batch,
            project=output_dir,
        )
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: wunderscout
|
|
3
|
+
Version: 0.1.11
|
|
4
|
+
Summary: Scouting and vision tools for YOLO and sports analytics.
|
|
5
|
+
Project-URL: Homepage, https://github.com/qhuboo/wunderscout
|
|
6
|
+
Project-URL: Issues, https://github.com/qhuboo/wunderscout/issues
|
|
7
|
+
Keywords: scouting,sports-analytics,vision,yolo
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Operating System :: OS Independent
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Requires-Python: >=3.11
|
|
12
|
+
Requires-Dist: matplotlib>=3.10.6
|
|
13
|
+
Requires-Dist: more-itertools>=10.8.0
|
|
14
|
+
Requires-Dist: networkx>=3.5
|
|
15
|
+
Requires-Dist: numpy>=2.3.2
|
|
16
|
+
Requires-Dist: opencv-python>=4.11.0.86
|
|
17
|
+
Requires-Dist: pandas>=2.3.2
|
|
18
|
+
Requires-Dist: python-dotenv>=1.1.1
|
|
19
|
+
Requires-Dist: roboflow>=1.2.7
|
|
20
|
+
Requires-Dist: scikit-learn>=1.7.2
|
|
21
|
+
Requires-Dist: scipy>=1.16.2
|
|
22
|
+
Requires-Dist: seaborn>=0.13.2
|
|
23
|
+
Requires-Dist: supervision>=0.26.1
|
|
24
|
+
Requires-Dist: torch>=2.8.0
|
|
25
|
+
Requires-Dist: tqdm>=4.67.1
|
|
26
|
+
Requires-Dist: transformers>=4.56.1
|
|
27
|
+
Requires-Dist: ultralytics>=8.3.193
|
|
28
|
+
Requires-Dist: umap-learn>=0.5.9.post2
|
|
29
|
+
Description-Content-Type: text/markdown
|
|
30
|
+
|
|
31
|
+
# wunderscout
|
|
32
|
+
|
|
33
|
+
A Python library for extracting player and ball tracking data from soccer match footage using YOLO, Siglip embeddings, and homography.
|
|
34
|
+
|
|
35
|
+
## Features
|
|
36
|
+
|
|
37
|
+
- **Detection & Tracking**: Uses YOLO for player/ball/pitch-keypoint detection and ByteTrack for temporal consistency.
|
|
38
|
+
- **Automated Team Clustering**: Groups players into teams by embedding player crops with a SigLIP vision transformer, reducing the embeddings with UMAP, and clustering them with K-Means.
|
|
39
|
+
- **Pitch Mapping**: Transforms 2D image coordinates to a normalized 0-1 coordinate system using pitch keypoint homography.
|
|
40
|
+
- **Goalkeeper Attribution**: Assigns goalkeepers to teams based on proximity to team centroids.
|
|
41
|
+
- **Data Export**: Generates Home and Away CSV files containing frame-by-frame XY coordinates.
|
|
42
|
+
|
|
43
|
+
## Installation
|
|
44
|
+
|
|
45
|
+
```bash
|
|
46
|
+
uv add wunderscout
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
## Usage
|
|
50
|
+
|
|
51
|
+
The `ScoutingPipeline` class manages calibration, tracking, and data export.
|
|
52
|
+
|
|
53
|
+
```python
|
|
54
|
+
from wunderscout import ScoutingPipeline
|
|
55
|
+
|
|
56
|
+
# Initialize with paths to trained YOLO weights
|
|
57
|
+
pipeline = ScoutingPipeline(
|
|
58
|
+
player_weights="players.pt",
|
|
59
|
+
field_weights="pitch.pt"
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
# Run processing
|
|
63
|
+
pipeline.run(
|
|
64
|
+
video_path="input_match.mp4",
|
|
65
|
+
output_video_path="output_match.mp4"
|
|
66
|
+
)
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
## Internal Components
|
|
70
|
+
|
|
71
|
+
- **VisionEngine**: Manages YOLO models and generates Siglip embeddings for player crops.
|
|
72
|
+
- **PitchMapper**: Computes homography matrices based on a 32-point pitch configuration. Handles RANSAC-based perspective transforms.
|
|
73
|
+
- **TeamClassifier**: Performs unsupervised clustering on player embeddings. Uses a rolling consensus buffer to stabilize team assignments across frames.
|
|
74
|
+
- **DataExporter**: Formats tracking results into CSV files with frame indices and normalized pitch coordinates.
|
|
75
|
+
|
|
76
|
+
## Dependencies
|
|
77
|
+
|
|
78
|
+
- `ultralytics`
|
|
79
|
+
- `supervision`
|
|
80
|
+
- `transformers`
|
|
81
|
+
- `umap-learn`
|
|
82
|
+
- `scikit-learn`
|
|
83
|
+
- `opencv-python`
|
|
84
|
+
|
|
85
|
+
## License
|
|
86
|
+
|
|
87
|
+
MIT
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
wunderscout/__init__.py,sha256=QPNgcOVyGGefsTeA8LyHsfcPGc2Q4M28J6zU5RYhzr0,355
|
|
2
|
+
wunderscout/core.py,sha256=icdSLesum0cZV0MqIlwoxA2L1bfP5nnAGAjiLZy9bSg,6409
|
|
3
|
+
wunderscout/data.py,sha256=l6wRBd6WZGskqIguPjzFTfW47zWZYf8qPwzLVbsesyA,1069
|
|
4
|
+
wunderscout/exporters.py,sha256=FduXc4q065uF3Hp4G1LMO7xXVRC1DFewPBWnFHmS1bI,1756
|
|
5
|
+
wunderscout/geometry.py,sha256=I5lt00O9jOiEoVpPGy5iVglzA7cgaAdUvzfuFBcJbRA,2197
|
|
6
|
+
wunderscout/heatmap.py,sha256=5R8Zw5Bnk-8eHTcWudo-1a3Mt83WrSbqYX9MkUoht_8,4268
|
|
7
|
+
wunderscout/heatmaps.py,sha256=_c3pKkkvDMfdhHaxG5S9bqOSEdpyxcgozxYogB8xaak,9695
|
|
8
|
+
wunderscout/pass_network.py,sha256=QC859Pi5VKSgHu1qrE3Zybvu97lNorsb03UAf1IrSbs,3099
|
|
9
|
+
wunderscout/teams.py,sha256=y0IclDACo3F8buVdpqqMCSmZJeWx2uqMkGNbZ6YToVc,2628
|
|
10
|
+
wunderscout/vision.py,sha256=fVX3wtwCwe6AiiGxZjH8u4q2gk3t4gCb4MNvmQf7Lhs,5257
|
|
11
|
+
wunderscout-0.1.11.dist-info/METADATA,sha256=TT0uNp1KrZWIDpuib26X2Eqx6k-nTSPe2RN6wu0bMI0,2922
|
|
12
|
+
wunderscout-0.1.11.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
13
|
+
wunderscout-0.1.11.dist-info/RECORD,,
|