wunderscout 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,6 @@
1
+ from .vision import VisionEngine
2
+ from .geometry import PitchMapper
3
+ from .teams import TeamClassifier
4
+ from .core import ScoutingPipeline
5
+
6
+ __all__ = ["VisionEngine", "PitchMapper", "TeamClassifier", "ScoutingPipeline"]
wunderscout/core.py ADDED
@@ -0,0 +1,164 @@
1
+ import cv2
2
+ import supervision as sv
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from .vision import VisionEngine
6
+ from .geometry import PitchMapper
7
+ from .teams import TeamClassifier
8
+ from .exporters import DataExporter
9
+
10
+
11
class ScoutingPipeline:
    """End-to-end scouting pipeline.

    Calibrates a team classifier on crops sampled from the clip, then for
    every frame: detects players/ball, estimates the pitch homography,
    tracks objects, assigns teams, stores normalised pitch coordinates and
    writes an annotated video. Finally exports Metrica-style CSV files.
    """

    def __init__(self, player_weights, field_weights):
        self.engine = VisionEngine(player_weights, field_weights)
        self.mapper = PitchMapper()
        self.classifier = TeamClassifier()

    def run(self, video_path, output_video_path):
        """Process *video_path*, writing an annotated video plus tracking CSVs.

        Parameters
        ----------
        video_path : str
            Input match footage.
        output_video_path : str
            Destination .mp4; the CSV filenames are derived from this path.
        """
        # 1. Warm-up (calibration): fit the team classifier on sampled crops.
        print("WORKER: Calibrating teams...")
        crops = self.engine.get_calibration_crops(video_path)
        if len(crops) > 0:
            embeddings = self.engine.get_embeddings(crops)
            self.classifier.fit(embeddings)
        else:
            print("WARNING: No player crops found for calibration.")

        # 2. Setup video I/O.
        output_path_obj = Path(output_video_path)
        output_path_obj.parent.mkdir(parents=True, exist_ok=True)

        # The capture is only needed for output metadata; release it right
        # away instead of holding it open for the whole run (was a leak).
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        cap.release()

        out = cv2.VideoWriter(
            output_video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
        )
        if not out.isOpened():
            print(f"ERROR: Could not create video file at {output_video_path}")
            return

        tracker = sv.ByteTrack()
        tracking_results = {}

        # Class ids produced by the player-detection model.
        BALL_ID = 0
        GOALKEEPER_ID = 1
        PLAYER_ID = 2
        REFEREE_ID = 3

        # 3. Main processing loop.
        print(f"WORKER: Starting processing: {video_path}")
        frame_generator = sv.get_video_frames_generator(video_path)

        # Count frames explicitly: fixes a NameError on zero-frame videos
        # and an off-by-one (the exporter iterates range(total_frames), so
        # passing the last *index* silently dropped the final frame).
        total_frames = 0
        for frame_idx, frame in enumerate(frame_generator):
            total_frames = frame_idx + 1
            if frame_idx % 100 == 0:
                print(f"WORKER: Processing frame {frame_idx}")

            # --- A. DETECTION ---
            all_dets = self.engine.detect_players(frame)
            f_res = self.engine.detect_field(frame)

            # --- B. FIELD HOMOGRAPHY ---
            if f_res.keypoints is not None and len(f_res.keypoints.xy) > 0:
                H = self.mapper.get_matrix(
                    f_res.keypoints.xy[0].cpu().numpy(),
                    f_res.keypoints.conf[0].cpu().numpy(),
                )
            else:
                # No keypoints this frame: reuse the last good homography.
                H = self.mapper.last_h

            # --- C. SEPARATE BALL & OTHERS ---
            ball_detections = all_dets[all_dets.class_id == BALL_ID]
            ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)

            other_detections = all_dets[all_dets.class_id != BALL_ID]
            other_detections = other_detections.with_nms(threshold=0.5)

            # --- D. TRACKING ---
            tracked_objects = tracker.update_with_detections(other_detections)

            tracked_players = tracked_objects[tracked_objects.class_id == PLAYER_ID]
            tracked_gks = tracked_objects[tracked_objects.class_id == GOALKEEPER_ID]
            tracked_refs = tracked_objects[tracked_objects.class_id == REFEREE_ID]

            # --- E. TEAM CLASSIFICATION ---
            # 1. Players: embed each crop and take the per-track consensus.
            if len(tracked_players) > 0:
                p_crops = [sv.crop_image(frame, xyxy) for xyxy in tracked_players.xyxy]
                p_pil = [sv.cv2_to_pillow(c) for c in p_crops]
                p_embeddings = self.engine.get_embeddings(p_pil)

                final_team_ids = [
                    self.classifier.get_consensus_team(tid, p_embeddings[i])
                    for i, tid in enumerate(tracked_players.tracker_id)
                ]
                tracked_players.class_id = np.array(final_team_ids)

            # 2. Goalkeepers inherit the team of the nearest player centroid.
            if len(tracked_gks) > 0 and len(tracked_players) > 0:
                tracked_gks.class_id = self.classifier.resolve_goalkeepers_team_id(
                    tracked_players, tracked_gks
                )

            # 3. Referees: shift id 3 -> 2 so annotator palettes line up.
            if len(tracked_refs) > 0:
                tracked_refs.class_id -= 1

            # --- F. DATA STORAGE ---
            tracking_results[frame_idx] = {"players": {}, "ball": None}
            data_targets = sv.Detections.merge([tracked_players, tracked_gks])

            if H is not None:
                if len(data_targets) > 0:
                    feet_coords = data_targets.get_anchors_coordinates(
                        sv.Position.BOTTOM_CENTER
                    )
                    transformed_feet = self.mapper.transform(feet_coords, H)

                    for i, tid in enumerate(data_targets.tracker_id):
                        px, py = transformed_feet[i]
                        # Clamp to the unit pitch to absorb homography noise.
                        tracking_results[frame_idx]["players"][tid] = (
                            max(0.0, min(1.0, px)),
                            max(0.0, min(1.0, py)),
                        )

                if len(ball_detections) > 0:
                    ball_coords = ball_detections.get_anchors_coordinates(
                        sv.Position.CENTER
                    )
                    transformed_ball = self.mapper.transform([ball_coords[0]], H)
                    bx, by = transformed_ball[0]
                    tracking_results[frame_idx]["ball"] = (
                        max(0.0, min(1.0, bx)),
                        max(0.0, min(1.0, by)),
                    )

            # --- G. DRAW & WRITE VIDEO ---
            all_tracked = sv.Detections.merge(
                [tracked_players, tracked_gks, tracked_refs]
            )
            annotated_frame = self.engine.draw_annotations(
                frame, all_tracked, ball_detections
            )
            out.write(annotated_frame)

        # 4. Cleanup.
        out.release()
        print(f"WORKER: Video saved to {output_video_path}")

        # Save CSVs next to the output video.
        final_assignments = self.classifier.get_final_assignments()
        csv_path = str(Path(output_video_path).with_suffix(".csv"))
        DataExporter.save_csvs(
            tracking_results, final_assignments, total_frames, fps, csv_path
        )
@@ -0,0 +1,43 @@
1
+ import csv
2
+ from pathlib import Path
3
+
4
+
5
class DataExporter:
    """Writes Metrica-style tracking CSVs, one file per team."""

    @staticmethod
    def save_csvs(tracking_data, team_assignments, total_frames, fps, output_path):
        """
        tracking_data: {frame_idx: {"ball": (x,y), "players": {id: (x,y)}}}
        team_assignments: {tracker_id: team_id}
        """
        target = Path(output_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        stem = str(target.with_suffix(""))

        # Split tracker ids by team label (0 = home, 1 = away).
        rosters = {0: [], 1: []}
        for tid, team in team_assignments.items():
            if team in rosters:
                rosters[team].append(tid)

        def dump(filename, label, roster):
            # Each player occupies an X column and a Y column; the first two
            # header rows repeat the team label and player id over both.
            with open(filename, "w", newline="") as handle:
                out = csv.writer(handle)
                out.writerow([""] * 3 + [label] * (2 * len(roster)) + [""] * 2)
                out.writerow(
                    [""] * 3
                    + [str(pid) for pid in roster for _ in range(2)]
                    + [""] * 2
                )

                columns = ["Period", "Frame", "Time [s]"]
                for pid in roster:
                    columns.append(f"Player{pid}_X")
                    columns.append(f"Player{pid}_Y")
                columns += ["Ball_X", "Ball_Y"]
                out.writerow(columns)

                # One row per frame; missing positions become NaN.
                empty = {"ball": None, "players": {}}
                for f_idx in range(total_frames):
                    snapshot = tracking_data.get(f_idx, empty)
                    record = [1, f_idx, f"{f_idx / fps:.2f}"]
                    for pid in roster:
                        record.extend(snapshot["players"].get(pid, ("NaN", "NaN")))
                    ball = snapshot["ball"]
                    record.extend(ball if ball else ("NaN", "NaN"))
                    out.writerow(record)

        dump(f"{stem}_Home.csv", "Home", sorted(rosters[0]))
        dump(f"{stem}_Away.csv", "Away", sorted(rosters[1]))
@@ -0,0 +1,74 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
# Canonical pitch-keypoint layout for the field-detection model, in
# normalised coordinates: x runs 0.0 (left goal line) -> 1.0 (right goal
# line); y runs 0.0 (top touchline) -> 1.0 (bottom touchline).
PITCH_CONFIG = {
    # Left goal line, top to bottom.
    0: (0.000, 0.000),  # top-left corner
    1: (0.000, 0.204),  # top edge of penalty box
    2: (0.000, 0.365),  # top edge of goal area
    3: (0.000, 0.635),  # bottom edge of goal area
    4: (0.000, 0.796),  # bottom edge of penalty box
    5: (0.000, 1.000),  # bottom-left corner
    # Left penalty area.
    6: (0.052, 0.365),
    7: (0.052, 0.635),
    8: (0.105, 0.500),  # penalty spot (left)
    9: (0.157, 0.204),
    10: (0.157, 0.392),
    11: (0.157, 0.608),
    12: (0.157, 0.796),
    # Midfield (centre circle and halfway line).
    13: (0.413, 0.500),
    14: (0.500, 0.000),
    15: (0.500, 0.365),
    16: (0.500, 0.635),
    17: (0.500, 1.000),
    18: (0.587, 0.500),
    # Right penalty area.
    19: (0.843, 0.204),
    20: (0.843, 0.392),
    21: (0.843, 0.608),
    22: (0.843, 0.796),
    23: (0.895, 0.500),  # penalty spot (right)
    24: (0.948, 0.365),
    25: (0.948, 0.635),
    # Right goal line, top to bottom.
    26: (1.000, 0.000),
    27: (1.000, 0.204),
    28: (1.000, 0.365),
    29: (1.000, 0.635),
    30: (1.000, 0.796),
    31: (1.000, 1.000),
}
43
+
44
+
45
class PitchMapper:
    """Maps image-space points onto the normalised pitch via homography."""

    def __init__(self, pitch_config=PITCH_CONFIG):
        self.pitch_config = pitch_config
        # Last successfully-estimated homography; reused whenever the
        # current frame cannot produce one.
        self.last_h = None

    def get_matrix(self, keypoints_xy, keypoints_conf):
        """Estimate the image->pitch homography from model keypoints.

        Keypoints with confidence <= 0.5, or without a known pitch
        location, are discarded. Falls back to the cached matrix when
        fewer than the 4 correspondences required remain.
        """
        src_points = []
        dst_points = []

        for i, (xy, conf) in enumerate(zip(keypoints_xy, keypoints_conf)):
            if conf > 0.5 and i in self.pitch_config:
                src_points.append(xy)
                dst_points.append(self.pitch_config[i])

        if len(src_points) >= 4:
            H, _ = cv2.findHomography(
                np.array(src_points), np.array(dst_points), cv2.RANSAC
            )
            # BUG FIX: findHomography can return None on degenerate input;
            # previously that clobbered a previously good cached matrix.
            if H is not None:
                self.last_h = H

        return self.last_h

    def transform(self, points, H=None):
        """Project image points with *H* (or the cached matrix).

        Returns an (N, 2) array of pitch coordinates, or [] when no
        homography is available or *points* is empty.
        """
        target_h = H if H is not None else self.last_h
        if target_h is None or len(points) == 0:
            return []

        points_reshaped = np.array(points).reshape(-1, 1, 2).astype(np.float32)
        projected = cv2.perspectiveTransform(points_reshaped, target_h)
        return projected.reshape(-1, 2)
wunderscout/heatmap.py ADDED
@@ -0,0 +1,115 @@
1
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import json
from pathlib import Path
from scipy.stats import gaussian_kde

# === Load the raw tracking data CSV ===
# NOTE: "header=2" → skip the first two rows (team/labels) and use row 3 as header
df = pd.read_csv(
    "./data/Sample_Game_1_RawTrackingData_Away_Team.csv",
    header=2,
)

# Ensure the output directory exists up front; previously every savefig /
# json.dump below crashed with FileNotFoundError when ./heatmap was missing.
Path("./heatmap").mkdir(parents=True, exist_ok=True)

# === Clean column names so each player has _X and _Y ===
# The raw file stores X in the named column and Y in the following
# unnamed column — expand each into explicit _X/_Y pairs.
cleaned_columns = []
colnames = df.columns.tolist()
i = 0
while i < len(colnames):
    col = colnames[i]
    if col.startswith("Player") or col.startswith("Ball"):
        cleaned_columns.append(f"{col}_X")
        cleaned_columns.append(f"{col}_Y")
        i += 2
    else:
        cleaned_columns.append(col)
        i += 1
df.columns = cleaned_columns

print("Columns cleaned. First few rows:")
print(df.head())

# === Extract Player17 (drop NaN values where tracking failed) ===
player17 = df[["Player17_X", "Player17_Y"]].dropna()
x = player17["Player17_X"].to_numpy()
y = player17["Player17_Y"].to_numpy()

# === Detect scale (normalized [0,1] or real meters) ===
if x.max() <= 1.5 and y.max() <= 1.5:
    print("Scaling Player17 data from normalized [0,1] to meters...")
    x = x * 105  # pitch length in meters
    y = y * 68  # pitch width in meters
else:
    print("Data appears to already be in meters, leaving as is.")

print("First 10 points:", list(zip(x[:10], y[:10])))

# =============================================================================
# 1. Scatter Plot (sanity check, raw positions)
# =============================================================================
fig, ax = plt.subplots(figsize=(10, 7))
# Pitch outline
ax.plot([0, 105, 105, 0, 0], [0, 0, 68, 68, 0], color="black")
ax.plot([52.5, 52.5], [0, 68], color="black")  # halfway line
# Player positions
ax.scatter(x, y, s=1, alpha=0.3, color="blue")
ax.set_xlim(0, 105)
ax.set_ylim(0, 68)
ax.set_title("Player17 Movement Scatter (raw positions)")
plt.savefig("./heatmap/player17_scatter.png", dpi=150, bbox_inches="tight")
plt.close(fig)  # free the figure; they otherwise accumulate in memory

# =============================================================================
# 2. Histogram Heatmap (occupancy grid)
# =============================================================================
heatmap, xedges, yedges = np.histogram2d(x, y, bins=(50, 34), range=[[0, 105], [0, 68]])

fig, ax = plt.subplots(figsize=(10, 7))
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
im = ax.imshow(
    heatmap.T, origin="lower", extent=extent, cmap="Blues", alpha=0.7, aspect="auto"
)
ax.plot([0, 105, 105, 0, 0], [0, 0, 68, 68, 0], color="black")
ax.plot([52.5, 52.5], [0, 68], color="black")
fig.colorbar(im, ax=ax, label="Frames")
ax.set_title("Player17 Heatmap (Histogram)")
plt.savefig("./heatmap/player17_histogram.png", dpi=150, bbox_inches="tight")
plt.close(fig)

# === Export histogram data as JSON for three.js ===
heatmap_data = {
    "xedges": xedges.tolist(),
    "yedges": yedges.tolist(),
    "values": heatmap.T.tolist(),  # transpose so rows correspond to y-axis correctly
}
with open("./heatmap/player17_histogram.json", "w") as f:
    json.dump(heatmap_data, f)

# =============================================================================
# 3. KDE Heatmap (smoothed density field)
# =============================================================================
values = np.vstack([x, y])
kde = gaussian_kde(values)

# Define mesh grid
X, Y = np.meshgrid(np.linspace(0, 105, 100), np.linspace(0, 68, 68))
Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)

fig, ax = plt.subplots(figsize=(10, 7))
sns.kdeplot(x=x, y=y, fill=True, cmap="Blues", alpha=0.7, thresh=0.05, levels=50, ax=ax)
ax.plot([0, 105, 105, 0, 0], [0, 0, 68, 68, 0], color="black")
ax.plot([52.5, 52.5], [0, 68], color="black")
ax.set_xlim(0, 105)
ax.set_ylim(0, 68)
ax.set_title("Player17 Heatmap (KDE Smoothed)")
plt.savefig("./heatmap/player17_kde.png", dpi=150, bbox_inches="tight")
plt.close(fig)

# === Export KDE density field for three.js ===
kde_data = {
    "x": X[0].tolist(),  # x grid coordinates
    "y": Y[:, 0].tolist(),  # y grid coordinates
    "values": Z.tolist(),  # density values
}
with open("./heatmap/player17_kde.json", "w") as f:
    json.dump(kde_data, f)

print("Outputs saved: scatter, histogram PNG+JSON, KDE PNG+JSON for Player17")
wunderscout/main.py ADDED
@@ -0,0 +1,598 @@
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from roboflow import Roboflow
4
+ import torch
5
+ from ultralytics import YOLO
6
+ import supervision as sv
7
+ import cv2
8
+ from tqdm import tqdm
9
+ from transformers import AutoProcessor, SiglipVisionModel
10
+ from more_itertools import chunked
11
+ import numpy as np
12
+ import umap
13
+ from sklearn.cluster import KMeans
14
+ import plotly.graph_objects as go
15
+ import base64
16
+ from io import BytesIO
17
+
18
+ load_dotenv()
19
+
20
+ SIGLIP_MODEL_PATH = "google/siglip-base-patch16-224"
21
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
+ EMBEDDINGS_MODEL = SiglipVisionModel.from_pretrained(SIGLIP_MODEL_PATH).to(DEVICE)
23
+ EMBEDDINGS_PROCESSOR = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH)
24
+ ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY")
25
+
26
+ FIELD_DETECTION_MODEL_ID = "football-field-detection-f07vi/14"
27
+
28
+
29
def train():
    """Download the Roboflow players dataset and fine-tune YOLO11-s on it."""
    print("Beginning training ...")
    # Pull the dataset from Roboflow, then kick off training.
    rf = Roboflow(api_key=ROBOFLOW_API_KEY)
    project = rf.workspace("roboflow-jvuqo").project(
        "football-players-detection-3zvbc"
    )
    project.version(20).download("yolov11")

    base_model = YOLO("yolo11s.pt")
    base_model.train(
        data="./football-players-detection-20/data.yaml",
        save=True,
        epochs=50,
        imgsz=1280,
        plots=True,
        device=0,
        batch=6,
        project="runs/detect",
    )
50
+
51
+
52
def inference(video_path):
    """Run the fine-tuned detector on the first frame of *video_path* and
    save an annotated still to ./runs/annotated_frame.jpg."""
    print("Inference 3 ...")
    # We do another inference with different annotations
    BALL_ID = 0

    model_trained = YOLO("./runs/detect/train/weights/best.pt")
    # CONSISTENCY FIX: the gold palette entry was "FFD700" (missing '#'),
    # unlike every other palette in this file; normalised to "#FFD700".
    ellipse_annotator = sv.EllipseAnnotator(
        color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]), thickness=2
    )

    triangle_annotator = sv.TriangleAnnotator(
        color=sv.Color.from_hex("#FFD700"), base=25, height=21, outline_thickness=1
    )

    frame_2 = next(sv.get_video_frames_generator(video_path))

    result_3 = model_trained.predict(
        frame_2, save=True, project="runs", name="inference"
    )[0]
    detections_2 = sv.Detections.from_ultralytics(result_3)

    # Ball boxes are padded so the triangle marker is easier to see.
    ball_detections = detections_2[detections_2.class_id == BALL_ID]
    ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)

    # Everything else: NMS, then shift ids to match the 3-colour palette.
    all_detections = detections_2[detections_2.class_id != BALL_ID]
    all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
    all_detections.class_id -= 1

    annotated_frame_2 = frame_2.copy()
    annotated_frame_2 = ellipse_annotator.annotate(
        scene=annotated_frame_2, detections=all_detections
    )
    annotated_frame_2 = triangle_annotator.annotate(
        scene=annotated_frame_2, detections=ball_detections
    )

    cv2.imwrite("./runs/annotated_frame.jpg", annotated_frame_2)
90
+
91
+
92
def inference_with_player_tracking(video_path):
    """Track players/referees across the whole video and write an
    annotated copy to ./runs/video_tracked.mp4."""
    print("Inference with tracking ...")
    BALL_ID = 0

    model_trained = YOLO("./runs/detect/train/weights/best.pt")
    ellipse_annotator_2 = sv.EllipseAnnotator(
        color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]), thickness=2
    )

    label_annotator_2 = sv.LabelAnnotator(
        color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
        text_color=sv.Color.from_hex("#000000"),
        text_position=sv.Position.BOTTOM_CENTER,
    )

    triangle_annotator_2 = sv.TriangleAnnotator(
        color=sv.Color.from_hex("#FFD700"),
        base=25,
        height=21,
        outline_thickness=1,
    )

    tracker = sv.ByteTrack()

    frame_generator_3 = sv.get_video_frames_generator(video_path)

    # BUG FIX: the capture is only needed to read output metadata; release
    # it immediately instead of leaking it for the whole run.
    cap = cv2.VideoCapture(video_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(
        "./runs/video_tracked.mp4", fourcc, fps, (frame_width, frame_height)
    )

    for frame in frame_generator_3:
        result_tracking = model_trained.predict(frame, conf=0.3)[0]

        detections = sv.Detections.from_ultralytics(result_tracking)

        # Padded ball boxes make the triangle marker easier to see.
        ball_detections = detections[detections.class_id == BALL_ID]
        ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)

        # Everything else: NMS, shift ids for the palette, then track.
        all_detections = detections[detections.class_id != BALL_ID]
        all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
        all_detections.class_id -= 1
        all_detections = tracker.update_with_detections(detections=all_detections)

        labels = [f"#{tracker_id}" for tracker_id in all_detections.tracker_id]

        annotated_frame = frame.copy()
        annotated_frame = ellipse_annotator_2.annotate(
            scene=annotated_frame, detections=all_detections
        )
        annotated_frame = label_annotator_2.annotate(
            scene=annotated_frame, detections=all_detections, labels=labels
        )
        annotated_frame = triangle_annotator_2.annotate(
            scene=annotated_frame, detections=ball_detections
        )

        out.write(annotated_frame)
    out.release()
157
+
158
+
159
def create_player_crops(video_path):
    """Collect player crops (as PIL images) from every STRIDE-th frame,
    used as calibration data for the team clustering model."""
    print("Generating player crops ...")
    # Getting training data for cluster model
    PLAYER_ID = 2
    STRIDE = 30

    detector = YOLO("./runs/detect/train/weights/best.pt")
    sampled_frames = sv.get_video_frames_generator(
        source_path=video_path, stride=STRIDE
    )

    crops = []
    for frame in tqdm(sampled_frames, desc="collecting crops"):
        detections = sv.Detections.from_ultralytics(
            detector.predict(frame, conf=0.3)[0]
        )
        detections = detections.with_nms(threshold=0.5, class_agnostic=True)
        detections = detections[detections.class_id == PLAYER_ID]
        crops.extend(sv.crop_image(frame, xyxy) for xyxy in detections.xyxy)

    print(f"Total crops collected: {len(crops)}.")

    # Hand back PIL images, which the SigLIP processor expects.
    return [sv.cv2_to_pillow(c) for c in crops]
192
+
193
+
194
def extract_embeddings(crops):
    """Embed player crops with SigLIP; returns an (N, D) numpy array."""
    BATCH_SIZE = 32

    collected = []
    with torch.no_grad():
        for batch in tqdm(chunked(crops, BATCH_SIZE), desc="embedding extraction"):
            inputs = EMBEDDINGS_PROCESSOR(images=batch, return_tensors="pt").to(DEVICE)
            outputs = EMBEDDINGS_MODEL(**inputs)
            # Mean-pool the token dimension to get one vector per crop.
            pooled = torch.mean(outputs.last_hidden_state, dim=1)
            collected.append(pooled.cpu().numpy())
    return np.concatenate(collected)
208
+
209
+
210
def cluster_players_by_team(embeddings):
    """Project embeddings to 3-D with UMAP, then split them into two
    K-means clusters (one per team)."""
    reducer = umap.UMAP(n_components=3)
    kmeans = KMeans(n_clusters=2, n_init=10, random_state=42)

    projections = reducer.fit_transform(embeddings)
    return projections, kmeans.fit_predict(projections)
218
+
219
+
220
def save_projection_plot_html(projections, clusters, crops):
    """Write an interactive 3-D scatter of the UMAP projections to
    ./runs/player_clusters.html; clicking a point previews its crop."""

    def encode_crop(image):
        # PIL image -> base64 PNG data URI, embedded directly in the HTML.
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{encoded}"

    # Convert all crops up front and give each a stable id.
    image_data_uris = {
        f"image_{i}": encode_crop(crop) for i, crop in enumerate(crops)
    }
    image_ids = np.array([f"image_{i}" for i in range(len(crops))])

    traces = []
    for cluster_label in np.unique(clusters):
        mask = clusters == cluster_label
        traces.append(
            go.Scatter3d(
                x=projections[mask][:, 0],
                y=projections[mask][:, 1],
                z=projections[mask][:, 2],
                mode="markers+text",  # hard-coded: markers+text
                text=clusters[mask],
                customdata=image_ids[mask],
                name=str(cluster_label),
                marker=dict(size=6),
                hovertemplate="<b>class: %{text}</b><br>image ID: %{customdata}<extra></extra>",
            )
        )

    # Shared cube axis range so the scene keeps a 1:1:1 aspect.
    low = np.min(projections)
    high = np.max(projections)
    padding = (high - low) * 0.05
    axis_range = [low - padding, high + padding]

    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(
            xaxis=dict(title="X", range=axis_range),
            yaxis=dict(title="Y", range=axis_range),
            zaxis=dict(title="Z", range=axis_range),
            aspectmode="cube",
        ),
        width=1000,
        height=1000,
        showlegend=True,
    )

    # Embed chart HTML with custom JS for the crop preview.
    plotly_div = fig.to_html(
        full_html=False, include_plotlyjs=True, div_id="scatter-plot-3d"
    )
    javascript_code = f"""
    <script>
    function displayImage(imageId) {{
        var imageElement = document.getElementById('image-display');
        var placeholderText = document.getElementById('placeholder-text');
        var imageDataURIs = {image_data_uris};
        imageElement.src = imageDataURIs[imageId];
        imageElement.style.display = 'block';
        placeholderText.style.display = 'none';
    }}

    var chartElement = document.getElementById('scatter-plot-3d');
    chartElement.on('plotly_click', function(data) {{
        var customdata = data.points[0].customdata;
        displayImage(customdata);
    }});
    </script>
    """

    html_template = f"""
    <!DOCTYPE html>
    <html>
    <head>
    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
    <style>
    #image-container {{
        position: fixed;
        top: 0;
        left: 0;
        width: 200px;
        height: 200px;
        padding: 5px;
        border: 1px solid #ccc;
        background-color: white;
        z-index: 1000;
        box-sizing: border-box;
        display: flex;
        align-items: center;
        justify-content: center;
        text-align: center;
    }}
    #image-display {{
        width: 100%;
        height: 100%;
        object-fit: contain;
    }}
    </style>
    </head>
    <body>
    {plotly_div}
    <div id="image-container">
        <img id="image-display" src="" alt="Selected image" style="display: none;" />
        <p id="placeholder-text">Click a data point to display the image</p>
    </div>
    {javascript_code}
    </body>
    </html>
    """

    out_path = "./runs/player_clusters.html"
    with open(out_path, "w") as f:
        f.write(html_template)

    print(f"Interactive plot saved to {out_path}. Open it in your browser.")
338
+
339
+
340
def resolve_goalkeepers_team_id(players, goalkeepers):
    """Assign each goalkeeper to the team whose player centroid is
    closest, measured at the bottom-centre (feet) anchor."""
    players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
    goalkeepers_xy = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)

    centroids = (
        players_xy[players.class_id == 0].mean(axis=0),
        players_xy[players.class_id == 1].mean(axis=0),
    )

    assigned = []
    for gk_xy in goalkeepers_xy:
        dist_to_0 = np.linalg.norm(gk_xy - centroids[0])
        dist_to_1 = np.linalg.norm(gk_xy - centroids[1])
        assigned.append(0 if dist_to_0 < dist_to_1 else 1)

    return np.array(assigned)
354
+
355
+
356
def inference_with_goalkeepers(video_path):
    """Full tracking pass with team clustering.

    Players are clustered into two teams via SigLIP embeddings + UMAP +
    K-means, goalkeepers inherit the nearest team centroid, and the
    annotated result is written to ./runs/video_tracked_2.mp4.
    """
    BALL_ID = 0
    GOALKEEPER_ID = 1
    PLAYER_ID = 2
    REFEREE_ID = 3

    # Calibrate the team model on crops sampled from the whole clip.
    crops = create_player_crops(video_path)
    embeddings = extract_embeddings(crops)

    reducer = umap.UMAP(n_components=3)
    projections = reducer.fit_transform(embeddings)
    clustering_model = KMeans(n_clusters=2, n_init=10, random_state=42).fit(
        projections
    )

    model_trained = YOLO("./runs/detect/train/weights/best.pt")

    # BUG FIX: previously the first frame was consumed with next() and
    # then discarded, so it never appeared in the output video.
    frame_generator = sv.get_video_frames_generator(video_path)

    tracker = sv.ByteTrack()

    # BUG FIX: the capture is only needed for output metadata; release it
    # immediately instead of leaking it for the whole run.
    cap = cv2.VideoCapture(video_path)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(
        "./runs/video_tracked_2.mp4", fourcc, fps, (frame_width, frame_height)
    )

    # Annotators are loop-invariant; build them once instead of per frame.
    ellipse_annotator = sv.EllipseAnnotator(
        color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
        thickness=2,
    )
    label_annotator = sv.LabelAnnotator(
        color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
        text_color=sv.Color.from_hex("#000000"),
        text_position=sv.Position.BOTTOM_CENTER,
    )
    triangle_annotator = sv.TriangleAnnotator(
        color=sv.Color.from_hex("#FFD700"), base=25, height=21, outline_thickness=1
    )

    for frame in frame_generator:
        result = model_trained.predict(frame, conf=0.3)[0]

        detections = sv.Detections.from_ultralytics(result)

        ball_detections = detections[detections.class_id == BALL_ID]
        ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)

        all_detections = detections[detections.class_id != BALL_ID]
        all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
        all_detections = tracker.update_with_detections(detections=all_detections)

        goalkeepers_detections = all_detections[
            all_detections.class_id == GOALKEEPER_ID
        ]
        players_detections = all_detections[all_detections.class_id == PLAYER_ID]
        referees_detections = all_detections[all_detections.class_id == REFEREE_ID]

        # Players: embed, project into the fitted UMAP space, assign team.
        players_crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
        player_embeddings = extract_embeddings(players_crops)
        player_projection = reducer.transform(player_embeddings)
        players_detections.class_id = clustering_model.predict(player_projection)

        # Goalkeepers inherit the team of the nearest player centroid.
        goalkeepers_detections.class_id = resolve_goalkeepers_team_id(
            players_detections, goalkeepers_detections
        )

        # Referees: shift id 3 -> 2 so the 3-colour palette lines up.
        referees_detections.class_id -= 1

        all_detections = sv.Detections.merge(
            [players_detections, goalkeepers_detections, referees_detections]
        )

        labels = [f"#{tracker_id}" for tracker_id in all_detections.tracker_id]

        all_detections.class_id = all_detections.class_id.astype(int)

        annotated_frame = frame.copy()
        annotated_frame = ellipse_annotator.annotate(
            scene=annotated_frame, detections=all_detections
        )
        annotated_frame = label_annotator.annotate(
            scene=annotated_frame, detections=all_detections, labels=labels
        )
        annotated_frame = triangle_annotator.annotate(
            scene=annotated_frame, detections=ball_detections
        )

        out.write(annotated_frame)
    out.release()
451
+
452
+
453
def keypoint_detection_train():
    """Download the Roboflow pitch-keypoint dataset and fine-tune a
    YOLO11-m pose model on it."""
    print("Beginning training ...")
    # Pull the dataset from Roboflow, then kick off training.
    rf = Roboflow(api_key=ROBOFLOW_API_KEY)
    rf.workspace("roboflow-jvuqo").project(
        "football-field-detection-f07vi"
    ).version(15).download("yolov8")

    pose_model = YOLO("yolo11m-pose.pt")
    pose_model.train(
        data="./football-field-detection-15/data.yaml",
        save=True,
        epochs=300,
        plots=True,
        imgsz=1080,
        device=0,
        batch=2,
        project="runs/keypoint_detect/",
    )
472
+
473
+
474
def keypoint_detection_inference(video_path: str) -> None:
    """Annotate confident pitch keypoints on the first frame and save the
    result to ./runs/annotated_frame_keypoint.jpg."""
    model = YOLO("./runs/keypoint_detect/train4/weights/best.pt")
    vertex_annotator = sv.VertexAnnotator(
        color=sv.Color.from_hex("#FF1493"),
        radius=8,
    )

    frame = next(sv.get_video_frames_generator(video_path))
    res = model.predict(frame, conf=0.3, iou=0.7, imgsz=1088, verbose=False)[0]

    # Nothing detected: save the untouched frame and bail out.
    if res.boxes is None or len(res.boxes) == 0:
        cv2.imwrite("./runs/annotated_frame_keypoint.jpg", frame)
        return

    # Keep only the highest-confidence pitch detection.
    best = int(np.argmax(res.boxes.conf.cpu().numpy()))
    xy = res.keypoints.xy[best].cpu().numpy()
    conf = res.keypoints.conf[best].cpu().numpy()

    # Drop low-confidence keypoints before drawing.
    xy = xy[conf > 0.5]

    key_points = sv.KeyPoints(xy=xy[np.newaxis, ...])
    annotated = vertex_annotator.annotate(scene=frame.copy(), key_points=key_points)
    cv2.imwrite("./runs/annotated_frame_keypoint.jpg", annotated)
500
+
501
+
502
def keypoint_detection_indices(video_path, output_path="keypoint_debug.mp4"):
    """Render every frame with its detected pitch-keypoint indices
    (1-based) overlaid — a debugging aid for the keypoint model."""
    print(f"DEBUG: Starting keypoint visualization for {video_path}")

    model = YOLO("./runs/keypoint_detect/train4/weights/best.pt")

    cap = cv2.VideoCapture(video_path)
    size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    fps = cap.get(cv2.CAP_PROP_FPS)

    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)

    frame_count = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break

        frame_count += 1
        if frame_count % 30 == 0:
            print(f"Processing frame {frame_count}...")

        result = model.predict(frame, conf=0.5, verbose=False)[0]

        if result.keypoints is not None and result.keypoints.xy.shape[0] > 0:
            points = result.keypoints.xy[0].cpu().numpy()
            confidences = result.keypoints.conf[0].cpu().numpy()

            for idx, ((px, py), conf) in enumerate(zip(points, confidences)):
                # Skip low-confidence points and the (0, 0) "missing" slot.
                if conf < 0.5 or (px == 0 and py == 0):
                    continue

                px, py = int(px), int(py)
                cv2.circle(frame, (px, py), 5, (0, 0, 255), -1)
                cv2.putText(
                    frame,
                    str(idx + 1),
                    (px + 10, py),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.6,
                    (0, 255, 255),
                    2,
                )
        else:
            cv2.putText(
                frame,
                "NO PITCH DETECTED",
                (50, 50),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (0, 0, 255),
                2,
            )
        out.write(frame)

    cap.release()
    out.release()
    print("Video generation complete.")
564
+
565
+
566
def keypoint_detection_val_test(model_path, data_yaml):
    """Evaluate a trained keypoint model on the test split and print the
    resulting metrics."""
    metrics = YOLO(model_path).val(data=data_yaml, split="test", plots=True)
    print(metrics)
571
+
572
+
573
def main():
    """Entry point: report CUDA status and run the selected pipeline stage."""
    print("CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("GPU name:", torch.cuda.get_device_name(0))

    video_path = "/home/lucas/Documents/dev/local/yolo_torch/video.mp4"

    # Uncomment the stage(s) you want to run:
    # train()
    # inference(video_path)
    # inference_with_player_tracking(video_path)
    # crops = create_player_crops(video_path)
    # embeddings = extract_embeddings(crops)
    # projections, clusters = cluster_players_by_team(embeddings)
    # save_projection_plot_html(projections, clusters, crops)
    # inference_with_goalkeepers(video_path)
    # keypoint_detection_train()
    # keypoint_detection_inference(video_path)
    keypoint_detection_indices(video_path)
    # keypoint_detection_val_test(
    #     "./runs/keypoint_detect/train4/weights/best.pt",
    #     "./football-field-detection-15/data.yaml",
    # )


if __name__ == "__main__":
    main()
@@ -0,0 +1,103 @@
1
"""Build and visualize a single team's pass network from a match-events JSON.

The event schema (``type``/``tactics``/``pass`` keys, 120x80 pitch
coordinates) matches the StatsBomb open-data format — TODO confirm against
the data source. Outputs:
  - ./pass_network/pass_network.json      (nodes + weighted links)
  - ./pass_network/pass_network_viz.png   (matplotlib pitch plot)
"""

import json
import networkx as nx
import matplotlib.pyplot as plt
from collections import defaultdict
import os
from pathlib import Path

# Load match events JSON (replace with your actual file path)
with open(
    "./data/3825818.json",
    "r",
) as f:
    events = json.load(f)

# Build mapping from player ID -> player name from the Starting XI event
player_id_to_name = {}

for ev in events:
    if ev["type"]["name"] == "Starting XI":
        for lineup in ev["tactics"]["lineup"]:
            pid = lineup["player"]["id"]
            name = lineup["player"]["name"]
            player_id_to_name[pid] = name

# Data structures for passes and positions
edges = defaultdict(int)  # (passer, recipient) -> count of passes
player_positions = defaultdict(list)  # player_id -> list of [x, y] positions

TEAM_NAME = "Real Sociedad"

# Extract completed passes
for ev in events:
    if ev["type"]["name"] == "Pass" and ev["team"]["name"] == TEAM_NAME:
        passer = ev["player"]["id"]
        recipient = ev.get("pass", {}).get("recipient", {}).get("id")
        # A pass event with no "outcome" key is complete — hence the default.
        outcome = ev.get("pass", {}).get("outcome", {"name": "Complete"})["name"]

        if outcome == "Complete" and recipient is not None:
            edges[(passer, recipient)] += 1
            start = ev["location"]
            end = ev["pass"]["end_location"]
            # Credit the passer at the pass origin, the recipient at the
            # reception point; both feed the average-position estimate.
            player_positions[passer].append(start)
            player_positions[recipient].append(end)

# Calculate avg positions
avg_positions = {}
for player_id, coords in player_positions.items():
    xs = [pt[0] for pt in coords]
    ys = [pt[1] for pt in coords]
    avg_positions[player_id] = [sum(xs) / len(xs), sum(ys) / len(ys)]

# Build a JSON-friendly structure for export (nodes + links)
nodes = [{"id": pid, "x": pos[0], "y": pos[1]} for pid, pos in avg_positions.items()]
links = [
    {"source": src, "target": tgt, "value": count}
    for (src, tgt), count in edges.items()
]

network = {"nodes": nodes, "links": links}
os.makedirs("pass_network", exist_ok=True)
with open("./pass_network/pass_network.json", "w") as f:
    json.dump(network, f, indent=2)

# Build NetworkX graph
G = nx.DiGraph()

# Add nodes with positions
for pid, pos in avg_positions.items():
    G.add_node(pid, pos=(pos[0], pos[1]))

# Add edges with weights
for (src, tgt), count in edges.items():
    G.add_edge(src, tgt, weight=count)

# Draw graph
pos = nx.get_node_attributes(G, "pos")
# Fall back to the numeric ID when a player never appeared in the Starting XI.
labels = {pid: player_id_to_name.get(pid, str(pid)) for pid in G.nodes()}

fig, ax = plt.subplots(figsize=(10, 7))

# Draw pitch outline (120 x 80 pitch coordinate system)
ax.set_xlim(0, 120)
ax.set_ylim(0, 80)
ax.plot([0, 120, 120, 0, 0], [0, 0, 80, 80, 0], color="black")

# Draw nodes
nx.draw_networkx_nodes(G, pos, ax=ax, node_color="skyblue", node_size=500)

# Draw edges — line width scales with the number of passes on that link
nx.draw_networkx_edges(
    G,
    pos,
    ax=ax,
    width=[d["weight"] * 0.2 for _, _, d in G.edges(data=True)],
    alpha=0.7,
    arrowsize=10,
)

# Draw player names
nx.draw_networkx_labels(G, pos, labels=labels, ax=ax, font_size=8)

plt.title("Team Pass Network")
plt.savefig("./pass_network/pass_network_viz.png", dpi=150, bbox_inches="tight")
wunderscout/teams.py ADDED
@@ -0,0 +1,76 @@
1
+ import numpy as np
2
+ import umap
3
+ from sklearn.cluster import KMeans
4
+ import supervision as sv
5
+
6
+
7
class TeamClassifier:
    """Clusters player-appearance embeddings into two teams.

    Embeddings are reduced with UMAP to 3 components, then partitioned with
    KMeans(k=2). Per-tracker predictions are smoothed over a sliding window
    of recent votes so a single noisy frame cannot flip a player's team.
    """

    # Number of recent per-frame team votes kept per tracker for smoothing.
    HISTORY_LIMIT = 50

    def __init__(self):
        self.reducer = umap.UMAP(n_components=3)
        self.clusterer = KMeans(n_clusters=2, n_init=10, random_state=42)
        # tracker_id -> list of recent 0/1 votes (bounded by HISTORY_LIMIT)
        self.history = {}

    def fit(self, embeddings):
        """Fit the reducer and clusterer on calibration embeddings (N, D)."""
        projections = self.reducer.fit_transform(embeddings)
        self.clusterer.fit(projections)

    def _majority(self, votes):
        """Return the majority team (0 or 1) for a non-empty vote list."""
        return 1 if (sum(votes) / len(votes)) > 0.5 else 0

    def get_consensus_team(self, tracker_id, embedding):
        """Classify one crop embedding and return the tracker's majority team.

        The raw per-frame prediction is appended to the tracker's history
        (oldest vote dropped past HISTORY_LIMIT) and the smoothed majority
        vote is returned.
        """
        proj = self.reducer.transform(embedding.reshape(1, -1))
        pred = self.clusterer.predict(proj)[0]

        votes = self.history.setdefault(tracker_id, [])
        votes.append(pred)
        if len(votes) > self.HISTORY_LIMIT:
            votes.pop(0)

        return self._majority(votes)

    def resolve_goalkeepers_team_id(self, players, goalkeepers):
        """
        Assigns goalkeepers to the team whose centroid is closest.
        players: sv.Detections (already classified with class_id 0 or 1)
        goalkeepers: sv.Detections

        Returns an int array of team ids, one per goalkeeper.
        """
        if len(players) == 0 or len(goalkeepers) == 0:
            return np.array([0] * len(goalkeepers))

        players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
        goalkeepers_xy = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)

        team_0_mask = players.class_id == 0
        team_1_mask = players.class_id == 1

        # BUG FIX: the old code used fake centroids ([0, 0] / [10000, 10000])
        # when a team had no detections, which was asymmetric — goalkeepers
        # near the pitch origin were wrongly pulled toward a phantom team-0
        # centroid. With only one team visible there is no distance signal,
        # so assign every goalkeeper to the detected team.
        if not np.any(team_0_mask):
            return np.ones(len(goalkeepers), dtype=int)
        if not np.any(team_1_mask):
            return np.zeros(len(goalkeepers), dtype=int)

        team_0_centroid = players_xy[team_0_mask].mean(axis=0)
        team_1_centroid = players_xy[team_1_mask].mean(axis=0)

        # Nearest-centroid assignment, vectorized over all goalkeepers.
        dist_0 = np.linalg.norm(goalkeepers_xy - team_0_centroid, axis=1)
        dist_1 = np.linalg.norm(goalkeepers_xy - team_1_centroid, axis=1)
        return np.where(dist_0 < dist_1, 0, 1)

    def get_final_assignments(self):
        """Return {tracker_id: team_id} from the accumulated vote history.

        Trackers with no recorded votes are omitted.
        """
        assignments = {}
        for tid, votes in self.history.items():
            if votes:
                assignments[tid] = self._majority(votes)
        return assignments
76
+
wunderscout/vision.py ADDED
@@ -0,0 +1,155 @@
1
+ import torch
2
+ from ultralytics import YOLO
3
+ import supervision as sv
4
+ from transformers import AutoProcessor, SiglipVisionModel, data
5
+ from roboflow import Roboflow
6
+ from tqdm import tqdm
7
+ from more_itertools import chunked
8
+ import numpy as np
9
+ from pathlib import Path
10
+
11
+
12
class VisionEngine:
    """Owns the player/field YOLO detectors, the SigLIP embedding model,
    and the supervision annotators used to render tracked detections."""

    def __init__(self, player_weights, field_weights, device=None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.player_model = YOLO(player_weights)
        self.field_model = YOLO(field_weights)

        # SigLIP vision tower, used for player-appearance embeddings.
        siglip_path = "google/siglip-base-patch16-224"
        self.siglip_model = SiglipVisionModel.from_pretrained(siglip_path).to(
            self.device
        )
        self.siglip_processor = AutoProcessor.from_pretrained(siglip_path)

        # --- Annotators ---
        # Palette: 0=Blue, 1=Pink, 2=Yellow (Referee)
        self.palette = sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"])
        self.ellipse_annotator = sv.EllipseAnnotator(color=self.palette, thickness=2)
        self.label_annotator = sv.LabelAnnotator(
            color=self.palette,
            text_color=sv.Color.from_hex("#000000"),
            text_position=sv.Position.BOTTOM_CENTER,
        )
        self.triangle_annotator = sv.TriangleAnnotator(
            color=sv.Color.from_hex("#FFD700"), base=25, height=21, outline_thickness=1
        )

    def get_calibration_crops(self, video_path, stride=30):
        """Sample every `stride`-th frame and return PIL crops of detected players."""
        PLAYER_ID = 2
        frames = sv.get_video_frames_generator(source_path=video_path, stride=stride)

        crops = []
        for frame in frames:
            detections = self.detect_players(frame)
            # Calibration only uses outfield players (class PLAYER_ID).
            players = detections[detections.class_id == PLAYER_ID]
            crops.extend(
                sv.cv2_to_pillow(sv.crop_image(frame, box)) for box in players.xyxy
            )

        print(f"VisionEngine: Collected {len(crops)} calibration crops.")
        return crops

    def get_embeddings(self, pil_crops, batch_size=32):
        """Embed PIL crops with SigLIP in batches; returns a stacked numpy array."""
        collected = []

        with torch.no_grad():
            for batch in chunked(pil_crops, batch_size):
                inputs = self.siglip_processor(images=batch, return_tensors="pt").to(
                    self.device
                )
                outputs = self.siglip_model(**inputs)
                # Mean-pool the token states into one vector per crop.
                pooled = torch.mean(outputs.last_hidden_state, dim=1).cpu().numpy()
                collected.append(pooled)

        if not collected:
            return np.array([])
        return np.concatenate(collected)

    def detect_players(self, frame, conf=0.3):
        """Run the player model on one frame and wrap it as sv.Detections."""
        prediction = self.player_model.predict(frame, conf=conf, verbose=False)[0]
        return sv.Detections.from_ultralytics(prediction)

    def detect_field(self, frame, conf=0.3):
        """Run the field model on one frame; returns the raw ultralytics result."""
        return self.field_model.predict(frame, conf=conf, verbose=False)[0]

    def draw_annotations(self, frame, all_detections, ball_detections):
        """Render ball and people annotations onto a copy of `frame`."""
        annotated_frame = frame.copy()

        # 1. Ball gets a triangle marker.
        annotated_frame = self.triangle_annotator.annotate(
            scene=annotated_frame, detections=ball_detections
        )

        # 2. Players / goalkeepers / referees get ellipses plus tracker labels.
        if len(all_detections) > 0:
            # The palette lookup requires integer class ids.
            all_detections.class_id = all_detections.class_id.astype(int)
            labels = [f"#{tracker_id}" for tracker_id in all_detections.tracker_id]

            annotated_frame = self.ellipse_annotator.annotate(
                scene=annotated_frame, detections=all_detections
            )
            annotated_frame = self.label_annotator.annotate(
                scene=annotated_frame, detections=all_detections, labels=labels
            )

        return annotated_frame
105
+
106
+
107
class ScoutingTrainer:
    """Downloads Roboflow datasets and fine-tunes the YOLO models used by
    the pipeline (player detector and field-keypoint detector)."""

    def __init__(self, api_key):
        self.rf = Roboflow(api_key=api_key)

    def train_players(
        self,
        workspace,
        project,
        version,
        epochs=300,
        output_dir="../runs/training/player",
    ):
        """Fine-tune the player-detection model.

        workspace/project/version identify the Roboflow dataset. Returns the
        ultralytics training results object.
        """
        # Local renamed to rf_project so the `project` parameter is not shadowed.
        rf_project = self.rf.workspace(workspace).project(project)
        dataset = rf_project.version(version).download("yolov11")
        model = YOLO("../data/base_models/yolo11m.pt")

        return model.train(
            data=f"{dataset.location}/data.yaml",
            epochs=epochs,
            imgsz=1280,
            plots=True,
            device=0,
            batch=2,
            project=output_dir,
        )

    def train_field(
        self,
        workspace,
        project,
        version,
        epochs=300,
        output_dir="../runs/training/field",
    ):
        """Fine-tune the field-keypoint (pose) model.

        BUG FIX: the caller-supplied `version` is now used instead of a
        hard-coded `project.version(15)`, matching train_players; the
        `project`/`version` parameters are no longer shadowed by locals.
        """
        rf_project = self.rf.workspace(workspace).project(project)
        dataset = rf_project.version(version).download(
            "yolov8", location="../data/data_sets/"
        )
        model = YOLO("yolo11m-pose.pt")

        return model.train(
            data=f"{dataset.location}/data.yaml",
            save=True,
            epochs=epochs,
            plots=True,
            imgsz=1080,
            device=0,
            batch=2,
            project=output_dir,
        )
@@ -0,0 +1,27 @@
1
+ Metadata-Version: 2.4
2
+ Name: wunderscout
3
+ Version: 0.1.2
4
+ Summary: Computer-vision football scouting toolkit (player detection, team classification, pitch mapping)
5
+ Requires-Python: >=3.11
6
+ Requires-Dist: ipython>=9.5.0
7
+ Requires-Dist: matplotlib>=3.10.6
8
+ Requires-Dist: memory-profiler>=0.61.0
9
+ Requires-Dist: more-itertools>=10.8.0
10
+ Requires-Dist: networkx>=3.5
11
+ Requires-Dist: numba>=0.58
12
+ Requires-Dist: numpy>=2.3.2
13
+ Requires-Dist: opencv-python>=4.11.0.86
14
+ Requires-Dist: pandas>=2.3.2
15
+ Requires-Dist: plotly>=6.3.0
16
+ Requires-Dist: protobuf>=6.32.1
17
+ Requires-Dist: psutil>=7.0.0
18
+ Requires-Dist: python-dotenv>=1.1.1
19
+ Requires-Dist: roboflow>=1.2.7
20
+ Requires-Dist: scikit-learn>=1.7.2
21
+ Requires-Dist: seaborn>=0.13.2
22
+ Requires-Dist: sentencepiece>=0.2.1
23
+ Requires-Dist: supervision>=0.26.1
24
+ Requires-Dist: tqdm>=4.67.1
25
+ Requires-Dist: transformers>=4.56.1
26
+ Requires-Dist: ultralytics>=8.3.193
27
+ Requires-Dist: umap-learn>=0.5.9.post2
@@ -0,0 +1,12 @@
1
+ wunderscout/__init__.py,sha256=UwZMG43lwIAxHCHPg8SYS9A2Azkem9vcZw6tkwFG9kU,217
2
+ wunderscout/core.py,sha256=suVWCsiVmwsevte8tsX_53GEdZ_cQhKcFTknY20BWEw,6345
3
+ wunderscout/exporters.py,sha256=CRczFAYUS6EOedL84wq9WaFWTPKJyGxL7kbZQvLzLOA,1875
4
+ wunderscout/geometry.py,sha256=I5lt00O9jOiEoVpPGy5iVglzA7cgaAdUvzfuFBcJbRA,2197
5
+ wunderscout/heatmap.py,sha256=5R8Zw5Bnk-8eHTcWudo-1a3Mt83WrSbqYX9MkUoht_8,4268
6
+ wunderscout/main.py,sha256=mkiZcWgAuAAc8bUmKQXLkw--PX8dwrRR6qKHZqmyi5M,20008
7
+ wunderscout/pass_network.py,sha256=QC859Pi5VKSgHu1qrE3Zybvu97lNorsb03UAf1IrSbs,3099
8
+ wunderscout/teams.py,sha256=y0IclDACo3F8buVdpqqMCSmZJeWx2uqMkGNbZ6YToVc,2628
9
+ wunderscout/vision.py,sha256=fVX3wtwCwe6AiiGxZjH8u4q2gk3t4gCb4MNvmQf7Lhs,5257
10
+ wunderscout-0.1.2.dist-info/METADATA,sha256=LUmQrjy8MaxHjTmCFILvR95dSQ8kWTYfpxpJPWudjx4,841
11
+ wunderscout-0.1.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
12
+ wunderscout-0.1.2.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.28.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any