opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,582 @@
1
+ # opensportslib/datasets/classification_dataset.py
2
+
3
+ """classification dataset implementations for video and tracking modalities.
4
+
5
+ provides two concrete dataset classes:
6
+
7
+ - VideoDataset (MVFoul, SN-GAR video)
8
+ - TrackingDataset (SN-GAR tracking / parquet)
9
+
10
+ both inherit from ClassificationDataset, which handles annotation loading,
11
+ label mapping, and class-weight computation.
12
+ """
13
+
14
+ import os
15
+ import random
16
+
17
+ import numpy as np
18
+ import torch
19
+ from torch.utils.data import Dataset
20
+ from tqdm import tqdm
21
+
22
+ from opensportslib.core.utils.load_annotations import load_annotations
23
+ from opensportslib.core.utils.video_processing import *
24
+
25
+
26
+ # -------------------------------------------------------------
27
+ # factory
28
+ # -------------------------------------------------------------
29
+
30
def build(config, annotations_path, processor=None, split="train"):
    """Instantiate the classification dataset matching ``config.DATA.data_modality``.

    The modality string is compared case-insensitively.

    Args:
        config: the loaded YAML configuration.
        annotations_path: path to the annotation JSON file.
        processor: HuggingFace image processor; only forwarded to the
            video dataset, ignored for tracking.
        split: one of "train", "valid", "test".

    Returns:
        a ``VideoDataset`` for "video" / "frames_npy", or a
        ``TrackingDataset`` for "tracking_parquet".

    Raises:
        ValueError: if the modality is not one of the supported values.
    """
    data_modality = config.DATA.data_modality.lower()

    if data_modality in ("video", "frames_npy"):
        return VideoDataset(config, annotations_path, processor, split)
    if data_modality == "tracking_parquet":
        return TrackingDataset(config, annotations_path, split)

    raise ValueError(f"Unknown data_modality: {data_modality}")
53
+
54
+
55
+ # -------------------------------------------------------------
56
+ # base class
57
+ # -------------------------------------------------------------
58
+
59
class ClassificationDataset(Dataset):
    """Shared base for all classification datasets.

    Loads annotations, builds a label map, and exposes helpers for
    computing sample-level and class-level weights (useful for
    balanced sampling and weighted loss respectively).

    Side effect: ``config.DATA.classes`` and ``config.DATA.num_classes``
    are overwritten from the loaded annotations so downstream components
    can read them.

    Args:
        config: the loaded YAML configuration.
        annotations_path: path to the annotation JSON file.
        processor: optional HuggingFace image processor, stored as
            ``self.processor`` (subclasses may override it).
        split: one of "train", "valid", "test", "infer". Missing labels are
            allowed for "test" and "infer".
    """

    def __init__(self, config, annotations_path, processor=None, split="train"):
        self.config = config
        self.split = split
        self.exclude_labels = ["Unknown", "Dont know"]
        self.data_dir = config.DATA.data_dir
        self.processor = processor

        # view_type is optional; only MVFoul uses it as of now
        is_multiview = getattr(config.DATA, "view_type", None) == "multi"

        # evaluation-only splits may come without ground-truth labels.
        allow_missing_labels = split in ["test", "infer"]

        self.samples, self.label_map = load_annotations(
            annotations_path,
            exclude_labels=self.exclude_labels,
            multiview=is_multiview,
            input_type=config.DATA.data_modality,
            allow_missing_labels=allow_missing_labels,
        )

        max_samples = getattr(config.DATA, "max_samples", None)
        if max_samples:
            self.samples = self.samples[:max_samples]

        # invert to id -> name and propagate into the config so
        # downstream components (metrics, logging) can look it up.
        self.label_map = {v: k for k, v in self.label_map.items()}
        self.config.DATA.classes = list(self.label_map.values())
        self.config.DATA.num_classes = len(self.label_map)

        print(self.config.DATA.num_classes, "classes:", self.config.DATA.classes)
        print("Label Map : ", self.label_map)

        self.has_labels = len(self.samples) > 0 and "label" in self.samples[0]

    # -- Sampling / loss weights ------------------------------------------

    def _label_tensor(self):
        """Collect every sample's integer label into a 1-D ``torch.long`` tensor."""
        # explicit dtype so an empty sample list still yields an integer
        # tensor that torch.bincount accepts.
        return torch.tensor([item["label"] for item in self.samples], dtype=torch.long)

    def get_sample_weights(self):
        """Per-sample inverse-frequency weights for WeightedRandomSampler.

        Each sample gets ``1 / count(its class)``.

        Returns:
            float32 torch.Tensor of length len(self) with one weight per
            sample (empty when there are no samples).

        Raises:
            KeyError: if a sample has no "label" entry.
        """
        labels = self._label_tensor()
        class_counts = torch.bincount(labels)
        class_weights = 1.0 / class_counts.float()
        # vectorized gather: weight of each sample's own class.
        return class_weights[labels]

    def get_class_weights(self, num_classes=None, normalize=True, sqrt=False):
        """Per-class inverse-frequency weights (e.g. for a weighted loss).

        Args:
            num_classes: if None, inferred as ``max(label) + 1``; falls back
                to the size of the label map when there are no samples.
            normalize: if True, weights are scaled so they sum to num_classes.
            sqrt: if True, use inverse square-root frequency instead of raw counts.

        Returns:
            torch.Tensor of shape (num_classes,) (longer if a label exceeds
            ``num_classes - 1``, as torch.bincount grows to fit).

        Raises:
            KeyError: if a sample has no "label" entry.
        """
        labels = self._label_tensor()

        if num_classes is None:
            if labels.numel() > 0:
                num_classes = int(labels.max().item() + 1)
            else:
                num_classes = len(self.label_map)

        counts = torch.bincount(labels, minlength=num_classes).float()
        counts[counts == 0] = 1.0  # avoid division by zero for unseen classes

        weights = 1.0 / torch.sqrt(counts) if sqrt else 1.0 / counts

        if normalize:
            weights = weights / weights.sum() * num_classes

        return weights

    def __len__(self):
        return len(self.samples)

    def num_classes(self):
        """Return the number of classes in the (id -> name) label map."""
        return len(self.label_map)
158
+
159
+
160
+ # -------------------------------------------------------------
161
+ # video modality
162
+ # -------------------------------------------------------------
163
+
164
class VideoDataset(ClassificationDataset):
    """Frame-sampled video clips for classification.

    For MVFoul: supports single-view and multi-view modes. In multi-view
    training, two views are randomly sampled per clip; at test time all
    available views are returned and stacked along a view dimension.

    Args:
        config: the loaded YAML configuration.
        annotations_path: path to the annotation JSON file.
        processor: HuggingFace image processor (used only for HuggingFace models).
        split: one of "train", "valid", "test".
    """

    # Per-channel mean/std applied to pre-extracted ``.npy`` frames (the
    # commonly used ImageNet statistics). Built once instead of per call.
    _NPY_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    _NPY_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __init__(self, config, annotations_path, processor, split="train"):
        # Forward arguments in the base-class order (config, annotations_path,
        # processor, split). Passing ``split`` as the third positional argument
        # would bind it to ``processor`` and leave the base ``split`` at its
        # "train" default, so allow_missing_labels was never enabled for
        # test/infer and build_transform always received mode="train".
        super().__init__(config, annotations_path, processor, split)

        self.processor = processor
        self.view_type = getattr(config.DATA, "view_type", "single")
        self.num_frames = getattr(config.DATA, "num_frames", None)
        self.input_fps = getattr(config.DATA, "input_fps", None)
        self.transform = build_transform(config, mode=self.split)

    def _select_views(self, video_paths):
        """Choose which camera views to load for this sample.

        Single-view returns only the first path; multi-view training samples
        up to two views at random; otherwise every path is returned.

        Args:
            video_paths: list of available view paths for the clip.

        Returns:
            a (possibly subsampled) list of paths.
        """
        if self.view_type == "single":
            return [video_paths[0]]

        if self.split.lower() == "train" and self.view_type == "multi":
            return random.sample(video_paths, min(2, len(video_paths)))

        return video_paths

    def _load_and_sample_clip(self, path):
        """Read a clip, temporally sub-sample it, and apply transforms.

        ``.npy`` inputs are loaded directly, scaled by 1/255 to float32,
        transformed, then mean/std normalized. Other files are decoded with
        ``read_video`` and sub-sampled with ``process_frames`` before the
        transform.

        Args:
            path: relative path (under ``config.DATA.data_dir``) to the clip.

        Returns:
            the frame array, presumably of shape (T, H, W, C) — the exact
            type/shape depends on ``self.transform``.
        """
        full_path = os.path.join(self.config.DATA.data_dir, path)

        if full_path.endswith(".npy"):
            frames = np.load(full_path).astype(np.float32) / 255.0
            if self.transform is not None:
                frames = self.transform(frames)
            return (frames - self._NPY_MEAN) / self._NPY_STD

        v = read_video(full_path)

        v = process_frames(
            v,
            self.num_frames,
            self.input_fps,
            self.config.DATA.target_fps,
            start_frame=self.config.DATA.start_frame,
            end_frame=self.config.DATA.end_frame,
        )

        if isinstance(v, list):
            v = np.stack(v)  # (T, H, W, C)

        if self.transform is not None:
            v = self.transform(v)

        return v  # (T, H, W, C)

    def __getitem__(self, idx):
        """Load the selected view(s) for sample ``idx``.

        Returns:
            dict with "pixel_values", "id" and, when the sample is labelled,
            "labels" (a ``torch.long`` scalar tensor).

        Raises:
            ValueError: if the sample has no video paths.
        """
        item = self.samples[idx]
        label = item.get("label", None)
        if label is not None:
            label = torch.tensor(label, dtype=torch.long)
        video_paths = item["video_paths"]
        sample_id = item["id"]

        # --- Choose which clips to load ---
        if not video_paths:
            raise ValueError(f"No video paths found for item {idx}")

        selected_paths = self._select_views(video_paths)

        # --- Load and process frames for selected clips ---
        if self.config.MODEL.type == "huggingface":
            # HuggingFace processors take a list of frames (each (H, W, C));
            # only the first selected view is used on this path.
            frames = list(self._load_and_sample_clip(selected_paths[0]))
            encoded = self.processor(frames, return_tensors="pt")
            pixel_values = encoded["pixel_values"].float().squeeze(0)
        else:
            view_tensors = []
            for path in selected_paths:
                v = self._load_and_sample_clip(path)

                if path.endswith(".npy"):
                    # frames_npy: keep (T, H, W, C)
                    v = torch.from_numpy(v)
                else:
                    # raw video: apply torchvision model transforms
                    v = torch.from_numpy(v).permute(0, 3, 1, 2)  # (T, C, H, W)
                    v = get_transforms_model(self.config.MODEL.pretrained_model)(v)  # (C, T, H, W)

                view_tensors.append(v)

            if selected_paths[0].endswith(".npy"):
                # frames_npy: single view, return (T, H, W, C) matching pixels_vs_positions
                pixel_values = view_tensors[0]
            else:
                # multi-view path: stack to (V, C, T, H, W)
                pixel_values = torch.stack(view_tensors, dim=0)

        out = {"pixel_values": pixel_values, "id": sample_id}
        if label is not None:
            out["labels"] = label
        return out
300
+
301
+
302
+ # -------------------------------------------------------------
303
+ # tracking modality
304
+ # -------------------------------------------------------------
305
+
306
class TrackingDataset(ClassificationDataset):
    """Graph-based classification dataset built from player tracking data.

    Each sample is a temporal sequence of per-frame graphs where nodes
    represent the ball and 22 players, and edges encode spatial or
    tactical relationships (see build_edge_index).

    Supports optional preloading of all clips into memory and
    training-time augmentation (horizontal/vertical/team flip).

    Args:
        config: the loaded YAML configuration.
        annotations_path: path to the annotation JSON file.
        split: one of "train", "valid", "test".
    """

    def __init__(self, config, annotations_path, split="train"):
        # The base signature is (config, annotations_path, processor, split).
        # Tracking has no image processor, so pass None explicitly; passing
        # ``split`` third would bind it to ``processor`` and leave the base
        # split stuck at its "train" default.
        super().__init__(config, annotations_path, None, split)

        from opensportslib.datasets.utils.tracking import (
            FEATURE_DIM,
            NUM_OBJECTS,
            HorizontalFlip,
            TeamFlip,
            VerticalFlip,
            build_edge_index,
        )

        # store references to the constants/helper so other methods do not
        # need to repeat the import.
        self._NUM_OBJECTS = NUM_OBJECTS
        self._FEATURE_DIM = FEATURE_DIM
        self._build_edge_index = build_edge_index

        self.num_frames = config.DATA.num_frames
        self.normalize = config.DATA.normalize
        self.edge_type = config.MODEL.edge
        self.k = config.MODEL.k
        self.r = config.MODEL.r
        self.preload_data = config.DATA.preload_data

        self.transforms = self._build_transforms(
            config, split, HorizontalFlip, VerticalFlip, TeamFlip
        )

        # filled by _preload_all_data(); stays None when loading on the fly.
        self.processed_samples = None
        if self.preload_data:
            self._preload_all_data()

    @staticmethod
    def _build_transforms(config, split, HorizontalFlip, VerticalFlip, TeamFlip):
        """Assemble the list of training-time augmentations.

        Args:
            config: the loaded YAML configuration.
            split: dataset split; augmentations are only applied during training.
            HorizontalFlip: augmentation class (passed to avoid re-importing).
            VerticalFlip: augmentation class.
            TeamFlip: augmentation class.

        Returns:
            list of callable augmentation transforms, each constructed with
            probability=0.5 (empty for non-training splits).
        """
        if split != "train":
            return []

        # the augmentations section and each flag are optional; default to off.
        aug_config = getattr(config.DATA, "augmentations", None)

        transforms = []
        if getattr(aug_config, "horizontal_flip", False):
            transforms.append(HorizontalFlip(probability=0.5))

        if getattr(aug_config, "vertical_flip", False):
            transforms.append(VerticalFlip(probability=0.5))

        if getattr(aug_config, "team_flip", False):
            transforms.append(TeamFlip(probability=0.5))

        return transforms

    def _parse_clip(self, clip_path):
        """Load one parquet clip and build its per-frame features and edges.

        Edge indices are built on the un-augmented, un-normalized features
        (after ``compute_deltas``) so the graph topology is deterministic
        and independent of augmentation.

        Args:
            clip_path: relative path (under ``data_dir``) to the parquet file.

        Returns:
            tuple ``(features, positions, edge_indices)``: ``features`` is
            ``compute_deltas`` applied to a float32 array of shape
            (num_frames, NUM_OBJECTS, FEATURE_DIM); ``positions`` and
            ``edge_indices`` are lists with one entry per frame.
        """
        from opensportslib.datasets.utils.tracking import compute_deltas, parse_frame

        df = self._load_tracking_clip(clip_path)

        num_frames = len(df)
        all_features = np.zeros(
            (num_frames, self._NUM_OBJECTS, self._FEATURE_DIM),
            dtype=np.float32,
        )
        all_positions = []

        for t, (_, row) in enumerate(df.iterrows()):
            features, positions = parse_frame(row)
            all_features[t] = features
            all_positions.append(positions)

        all_features = compute_deltas(all_features)

        edge_indices = [
            self._build_edge_index(
                all_features[t],
                all_positions[t],
                self.edge_type,
                self.k,
                self.r,
            )
            for t in range(num_frames)
        ]
        return all_features, all_positions, edge_indices

    def _finalize_sample(self, features, edge_indices, sample_id, label):
        """Augment/normalize features and wrap each frame in a PyG ``Data`` graph.

        Args:
            features: per-frame feature array; transforms are applied to it
                as given, so pass a copy if it must not be modified.
            edge_indices: one edge index per frame.
            sample_id: the sample's id, copied into the output.
            label: the sample's label, or None for unlabeled samples.

        Returns:
            dict with "graphs", "seq_len", "id" and (when labelled) "label".
        """
        from torch_geometric.data import Data

        from opensportslib.datasets.utils.tracking import normalize_features

        for transform in self.transforms:
            features = transform(features)

        if self.normalize:
            features = normalize_features(features)

        # one PyG Data object per frame. The downstream collate function
        # (tracking_collate) uses PyG Batch.from_data_list to merge these
        # across the batch dimension.
        graphs = [
            Data(
                x=torch.tensor(features[t], dtype=torch.float),
                edge_index=torch.tensor(edge_indices[t], dtype=torch.long),
            )
            for t in range(features.shape[0])
        ]

        out = {
            "graphs": graphs,
            "seq_len": len(graphs),
            "id": sample_id,
        }
        if label is not None:
            out["label"] = label
        return out

    def _preload_all_data(self):
        """Parse every clip and cache features and edge indices in memory.

        Samples without any tracking path are skipped and removed from
        ``self.samples`` so that ``processed_samples[i]`` always corresponds
        to ``self.samples[i]`` and ``len(self)`` stays correct.
        """
        print(f"Preloading {len(self.samples)} {self.split} samples into memory...")

        self.processed_samples = []
        kept_samples = []

        for item in tqdm(self.samples, desc=f"Loading {self.split}"):
            clip_paths = item["video_paths"]
            if not clip_paths:
                continue

            features, positions, edge_indices = self._parse_clip(clip_paths[0])

            entry = {
                "features": features,
                "positions": positions,
                "edge_indices": edge_indices,
                "id": item["id"],
            }
            # unlabeled splits ("test"/"infer") may lack a label.
            label = item.get("label")
            if label is not None:
                entry["label"] = label

            self.processed_samples.append(entry)
            kept_samples.append(item)

        skipped = len(self.samples) - len(kept_samples)
        if skipped:
            print(f"Skipped {skipped} {self.split} samples with no tracking paths")
        self.samples = kept_samples

        print(f"Loaded {len(self.processed_samples)} {self.split} samples")

    def _load_tracking_clip(self, path):
        """Read a single parquet tracking clip.

        Args:
            path: relative path (under ``data_dir``) to the parquet file.

        Returns:
            ``pandas.DataFrame`` with one row per frame.
        """
        import pandas as pd

        full_path = os.path.join(self.data_dir, path)
        return pd.read_parquet(full_path)

    def __getitem__(self, idx):
        if self.preload_data:
            return self._getitem_preloaded(idx)
        return self._getitem_on_the_fly(idx)

    def _getitem_preloaded(self, idx):
        """Return a sample from the in-memory cache.

        A copy of the feature array is made before augmentation and
        normalization so the cached data is never mutated.
        """
        sample = self.processed_samples[idx]
        return self._finalize_sample(
            sample["features"].copy(),
            sample["edge_indices"],
            sample["id"],
            sample.get("label"),
        )

    def _getitem_on_the_fly(self, idx):
        """Load, parse, and process a single sample from disk.

        Raises:
            ValueError: if the sample has no tracking paths.
        """
        item = self.samples[idx]

        clip_paths = item["video_paths"]
        if not clip_paths:
            raise ValueError(f"No tracking paths found for item {idx}")

        features, _, edge_indices = self._parse_clip(clip_paths[0])
        return self._finalize_sample(
            features, edge_indices, item["id"], item.get("label")
        )
582
+