opensportslib 0.1.2.dev10__tar.gz → 0.1.2.dev11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. {opensportslib-0.1.2.dev10/opensportslib.egg-info → opensportslib-0.1.2.dev11}/PKG-INFO +1 -1
  2. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/config/classification.yaml +14 -5
  3. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/trainer/classification_trainer.py +28 -10
  4. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/setup/setup.py +34 -10
  5. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11/opensportslib.egg-info}/PKG-INFO +1 -1
  6. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib.egg-info/SOURCES.txt +1 -0
  7. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/pyproject.toml +1 -1
  8. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tests/conftest.py +24 -6
  9. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tests/test_classification_dataset_paths.py +24 -2
  10. opensportslib-0.1.2.dev11/tests/test_classification_trainer_dataloader.py +127 -0
  11. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tests/test_task_model_api_contract.py +5 -4
  12. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/LICENSE +0 -0
  13. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/LICENSE-COMMERCIAL +0 -0
  14. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/MANIFEST.in +0 -0
  15. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/README.md +0 -0
  16. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/examples/quickstart/basic_classification.py +0 -0
  17. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/examples/quickstart/basic_localization.py +0 -0
  18. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/__init__.py +0 -0
  19. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/apis/__init__.py +0 -0
  20. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/apis/base_task_model.py +0 -0
  21. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/apis/classification.py +0 -0
  22. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/apis/localization.py +0 -0
  23. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/cli.py +0 -0
  24. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/config/localization-e2e-ocv.yaml +0 -0
  25. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/config/localization-json_calf_resnetpca512.yaml +0 -0
  26. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/config/localization-json_netvlad++_resnetpca512.yaml +0 -0
  27. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/config/localization.yaml +0 -0
  28. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/config/sngar-frames.yaml +0 -0
  29. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/config/sngar-tracking.yaml +0 -0
  30. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/__init__.py +0 -0
  31. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/loss/__init__.py +0 -0
  32. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/loss/builder.py +0 -0
  33. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/loss/calf.py +0 -0
  34. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/loss/ce.py +0 -0
  35. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/loss/combine.py +0 -0
  36. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/loss/nll.py +0 -0
  37. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/optimizer/__init__.py +0 -0
  38. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/optimizer/builder.py +0 -0
  39. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/sampler/weighted_sampler.py +0 -0
  40. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/scheduler/__init__.py +0 -0
  41. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/scheduler/builder.py +0 -0
  42. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/trainer/__init__.py +0 -0
  43. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/trainer/localization_trainer.py +0 -0
  44. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/utils/checkpoint.py +0 -0
  45. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/utils/config.py +0 -0
  46. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/utils/data.py +0 -0
  47. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/utils/ddp.py +0 -0
  48. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/utils/default_args.py +0 -0
  49. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/utils/lightning.py +0 -0
  50. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/utils/load_annotations.py +0 -0
  51. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/utils/seed.py +0 -0
  52. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/utils/video_processing.py +0 -0
  53. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/core/utils/wandb.py +0 -0
  54. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/datasets/__init__.py +0 -0
  55. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/datasets/builder.py +0 -0
  56. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/datasets/classification_dataset.py +0 -0
  57. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/datasets/localization_dataset.py +0 -0
  58. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/datasets/utils/__init__.py +0 -0
  59. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/datasets/utils/tracking.py +0 -0
  60. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/metrics/classification_metric.py +0 -0
  61. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/metrics/localization_metric.py +0 -0
  62. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/__init__.py +0 -0
  63. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/backbones/builder.py +0 -0
  64. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/base/contextaware.py +0 -0
  65. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/base/e2e.py +0 -0
  66. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/base/learnablepooling.py +0 -0
  67. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/base/tracking.py +0 -0
  68. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/base/vars.py +0 -0
  69. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/base/video.py +0 -0
  70. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/base/video_mae.py +0 -0
  71. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/builder.py +0 -0
  72. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/heads/builder.py +0 -0
  73. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/neck/builder.py +0 -0
  74. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/utils/common.py +0 -0
  75. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/utils/impl/__init__.py +0 -0
  76. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/utils/impl/asformer.py +0 -0
  77. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/utils/impl/calf.py +0 -0
  78. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/utils/impl/gsm.py +0 -0
  79. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/utils/impl/gtad.py +0 -0
  80. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/utils/impl/tsm.py +0 -0
  81. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/utils/litebase.py +0 -0
  82. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/utils/modules.py +0 -0
  83. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/utils/shift.py +0 -0
  84. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/models/utils/utils.py +0 -0
  85. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/tools/__init__.py +0 -0
  86. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/tools/_common.py +0 -0
  87. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/tools/hf_transfer.py +0 -0
  88. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/tools/osl_json_to_parquet.py +0 -0
  89. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib/tools/parquet_to_osl_json.py +0 -0
  90. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib.egg-info/dependency_links.txt +0 -0
  91. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib.egg-info/entry_points.txt +0 -0
  92. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib.egg-info/requires.txt +0 -0
  93. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/opensportslib.egg-info/top_level.txt +0 -0
  94. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/setup.cfg +0 -0
  95. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tests/test_config_utils_smoke.py +0 -0
  96. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tests/test_conversion_tools.py +0 -0
  97. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tests/test_hf_transfer_tools.py +0 -0
  98. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tests/test_package_smoke.py +0 -0
  99. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tests/test_public_apis_smoke.py +0 -0
  100. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tests/test_subset_train_infer_integration.py +0 -0
  101. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tools/convert/build_soccernet_gar.py +0 -0
  102. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tools/convert/build_soccernet_gar_action_spotting.py +0 -0
  103. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tools/convert/osl_json_to_parquet_webdataset.py +0 -0
  104. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tools/convert/parquet_webdataset_to_osl_json.py +0 -0
  105. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tools/download/download_hf_repo.py +0 -0
  106. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tools/download/download_osl_hf.py +0 -0
  107. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tools/download/upload_osl_hf.py +0 -0
  108. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tools/training/classification.py +0 -0
  109. {opensportslib-0.1.2.dev10 → opensportslib-0.1.2.dev11}/tools/training/localization.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.2.dev10
3
+ Version: 0.1.2.dev11
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -1,35 +1,44 @@
1
1
  TASK: classification
2
2
  DATA:
3
3
  dataset_name: mvfouls
4
- data_dir: /home/vorajv/opensportslib/SoccerNet/mvfouls
4
+ data_dir: /home/giancos/datasets/OpenSportsLab/OSL-XFoul/224p
5
5
  data_modality: video
6
6
  view_type: multi # multi or single
7
7
  num_classes: 8 # mvfoul
8
8
  train:
9
9
  type: annotations_train.json
10
10
  video_path: ${DATA.data_dir}/train
11
- path: ${DATA.train.video_path}/annotations-train.json
11
+ path: ${DATA.train.video_path}/train.json
12
12
  dataloader:
13
13
  batch_size: 8
14
14
  shuffle: true
15
15
  num_workers: 4
16
16
  pin_memory: true
17
+ mp_context: spawn
18
+ persistent_workers: true
19
+ prefetch_factor: 4
17
20
  valid:
18
21
  type: annotations_valid.json
19
22
  video_path: ${DATA.data_dir}/valid
20
- path: ${DATA.valid.video_path}/annotations-valid.json
23
+ path: ${DATA.valid.video_path}/valid.json
21
24
  dataloader:
22
25
  batch_size: 1
23
26
  num_workers: 1
24
27
  shuffle: false
28
+ mp_context: spawn
29
+ persistent_workers: true
30
+ prefetch_factor: 4
25
31
  test:
26
32
  type: annotations_test.json
27
33
  video_path: ${DATA.data_dir}/test
28
- path: ${DATA.test.video_path}/annotations-test.json
34
+ path: ${DATA.test.video_path}/test.json
29
35
  dataloader:
30
36
  batch_size: 1
31
- num_workers: 1
37
+ num_workers: 0
32
38
  shuffle: false
39
+ mp_context: spawn
40
+ persistent_workers: true
41
+ prefetch_factor: 4
33
42
  num_frames: 16 # 8 before + 8 after the foul
34
43
  input_fps: 25 # Original FPS of video
35
44
  target_fps: 17 # Temporal downsampling to 1s clip (approx)
@@ -14,6 +14,7 @@ import gc
14
14
  import json
15
15
  import time
16
16
  import logging
17
+ import multiprocessing as mp
17
18
 
18
19
  import torch
19
20
  import tqdm
@@ -865,36 +866,53 @@ class Trainer_Classification:
865
866
  else:
866
867
  val_sampler = None
867
868
 
868
- num_train_workers = self.config.DATA.train.dataloader.num_workers
869
- num_val_workers = self.config.DATA.valid.dataloader.num_workers
869
+ train_num_workers = getattr(self.config.DATA.train.dataloader, "num_workers", 0)
870
+ train_pin_memory = getattr(self.config.DATA.train.dataloader, "pin_memory", self.device.type == "cuda")
871
+ train_mp_context = getattr(self.config.DATA.train.dataloader, "mp_context", None)
872
+ train_persistent_workers = getattr(self.config.DATA.train.dataloader, "persistent_workers", train_num_workers > 0)
873
+ train_prefetch_factor = getattr(self.config.DATA.train.dataloader, "prefetch_factor", 4 if train_num_workers > 0 else None)
874
+
875
+ if train_mp_context is not None:
876
+ train_mp_context = mp.get_context(train_mp_context)
870
877
 
871
878
  train_loader = DataLoader(
872
879
  train_dataset,
873
880
  batch_size=self.config.DATA.train.dataloader.batch_size,
874
881
  shuffle=(train_sampler is None and shuffle),
875
882
  sampler=train_sampler,
876
- num_workers=num_train_workers,
877
- pin_memory=True,
883
+ num_workers=train_num_workers,
884
+ pin_memory=train_pin_memory,
878
885
  collate_fn=collate_fn,
879
886
  worker_init_fn=seed_worker,
880
887
  generator=g,
881
888
  drop_last=True,
882
- persistent_workers=num_train_workers > 0,
883
- prefetch_factor=4 if num_train_workers > 0 else None,
889
+ multiprocessing_context=train_mp_context,
890
+ persistent_workers=train_persistent_workers,
891
+ prefetch_factor=train_prefetch_factor,
884
892
  )
885
893
 
894
+ valid_num_workers = getattr(self.config.DATA.valid.dataloader, "num_workers", 0)
895
+ valid_pin_memory = getattr(self.config.DATA.valid.dataloader, "pin_memory", self.device.type == "cuda")
896
+ valid_mp_context = getattr(self.config.DATA.valid.dataloader, "mp_context", None)
897
+ valid_persistent_workers = getattr(self.config.DATA.valid.dataloader, "persistent_workers", valid_num_workers > 0)
898
+ valid_prefetch_factor = getattr(self.config.DATA.valid.dataloader, "prefetch_factor", 4 if valid_num_workers > 0 else None)
899
+
900
+ if valid_mp_context is not None:
901
+ valid_mp_context = mp.get_context(valid_mp_context)
902
+
886
903
  val_loader = DataLoader(
887
904
  val_dataset,
888
905
  batch_size=self.config.DATA.valid.dataloader.batch_size,
889
906
  shuffle=False,
890
907
  sampler=val_sampler,
891
- num_workers=num_val_workers,
892
- pin_memory=True,
908
+ num_workers=valid_num_workers,
909
+ pin_memory=valid_pin_memory,
893
910
  collate_fn=collate_fn,
894
911
  worker_init_fn=seed_worker,
895
912
  generator=g,
896
- persistent_workers=num_val_workers > 0,
897
- prefetch_factor=4 if num_val_workers > 0 else None,
913
+ multiprocessing_context=valid_mp_context,
914
+ persistent_workers=valid_persistent_workers,
915
+ prefetch_factor=valid_prefetch_factor,
898
916
  )
899
917
 
900
918
  # select the modality-specific trainer.
@@ -31,6 +31,17 @@ def get_cpu_tag():
31
31
  def install_torch():
32
32
  python = sys.executable
33
33
  subprocess.call([python, "-m", "pip", "uninstall", "-y", "torch", "torchvision"])
34
+
35
+ if CUDA_VERSION == "cu130":
36
+ cuda = "cu130"
37
+ subprocess.check_call([
38
+ python, "-m", "pip", "install",
39
+ "torch", "torchvision", "torchaudio",
40
+ "--index-url",
41
+ f"https://download.pytorch.org/whl/{cuda}"
42
+ ])
43
+ print(f"\nSuccess with {cuda}")
44
+ return cuda
34
45
  for cuda in CUDA_SUPPORT:
35
46
 
36
47
  print(f"\n Trying installation: {cuda}\n")
@@ -64,16 +75,29 @@ def install_dali():
64
75
 
65
76
  # DALI (only if GPU)
66
77
  if CUDA_VERSION:
67
- subprocess.check_call([
68
- python, "-m", "pip", "install",
69
- "nvidia-dali-cuda120"
70
- ])
71
-
72
- # CuPy (CUDA-aware but auto-resolves internally)
73
- subprocess.check_call([
74
- python, "-m", "pip", "install",
75
- "cupy-cuda12x"
76
- ])
78
+
79
+ if CUDA_VERSION == "cu130":
80
+ subprocess.check_call([
81
+ python, "-m", "pip", "install",
82
+ "nvidia-dali-cuda130"
83
+ ])
84
+
85
+ # CuPy (CUDA-aware but auto-resolves internally)
86
+ subprocess.check_call([
87
+ python, "-m", "pip", "install",
88
+ "cupy-cuda130"
89
+ ])
90
+ else:
91
+ subprocess.check_call([
92
+ python, "-m", "pip", "install",
93
+ "nvidia-dali-cuda120"
94
+ ])
95
+
96
+ # CuPy (CUDA-aware but auto-resolves internally)
97
+ subprocess.check_call([
98
+ python, "-m", "pip", "install",
99
+ "cupy-cuda12x"
100
+ ])
77
101
 
78
102
  def install_pyg():
79
103
  import torch
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: opensportslib
3
- Version: 0.1.2.dev10
3
+ Version: 0.1.2.dev11
4
4
  Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
5
5
  Author: Jeet Vora
6
6
  Requires-Python: >=3.12
@@ -88,6 +88,7 @@ opensportslib/tools/osl_json_to_parquet.py
88
88
  opensportslib/tools/parquet_to_osl_json.py
89
89
  tests/conftest.py
90
90
  tests/test_classification_dataset_paths.py
91
+ tests/test_classification_trainer_dataloader.py
91
92
  tests/test_config_utils_smoke.py
92
93
  tests/test_conversion_tools.py
93
94
  tests/test_hf_transfer_tools.py
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "opensportslib"
7
- version = "0.1.2.dev10"
7
+ version = "0.1.2.dev11"
8
8
  description = "OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data."
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.12"
@@ -185,8 +185,11 @@ def localization_integration_assets(tmp_path: Path) -> dict:
185
185
  "video_path": str(data_dir),
186
186
  "dataloader": {
187
187
  "batch_size": 1,
188
- "num_workers": 0,
188
+ "num_workers": 4,
189
189
  "pin_memory": False,
190
+ "mp_context": "spawn",
191
+ "persistent_workers": True,
192
+ "prefetch_factor": 4
190
193
  },
191
194
  },
192
195
  "valid": {
@@ -194,8 +197,11 @@ def localization_integration_assets(tmp_path: Path) -> dict:
194
197
  "video_path": str(data_dir),
195
198
  "dataloader": {
196
199
  "batch_size": 1,
197
- "num_workers": 0,
200
+ "num_workers": 1,
198
201
  "pin_memory": False,
202
+ "mp_context": "spawn",
203
+ "persistent_workers": True,
204
+ "prefetch_factor": 1
199
205
  },
200
206
  },
201
207
  "test": {
@@ -204,8 +210,11 @@ def localization_integration_assets(tmp_path: Path) -> dict:
204
210
  "results": str(result_path),
205
211
  "dataloader": {
206
212
  "batch_size": 1,
207
- "num_workers": 0,
213
+ "num_workers": 1,
208
214
  "pin_memory": False,
215
+ "mp_context": "spawn",
216
+ "persistent_workers": True,
217
+ "prefetch_factor": 1
209
218
  },
210
219
  },
211
220
  },
@@ -323,18 +332,27 @@ def localization_public_dataset_assets(tmp_path: Path) -> dict:
323
332
  "train": {
324
333
  "path": train_path,
325
334
  "video_path": str(data_dir / "train"),
326
- "dataloader": {"batch_size": 1, "num_workers": 0, "pin_memory": False},
335
+ "dataloader": {"batch_size": 1, "num_workers": 0, "pin_memory": False,
336
+ "mp_context": "spawn",
337
+ "persistent_workers": True,
338
+ "prefetch_factor": 4},
327
339
  },
328
340
  "valid": {
329
341
  "path": valid_path,
330
342
  "video_path": str(data_dir / "valid"),
331
- "dataloader": {"batch_size": 1, "num_workers": 0, "pin_memory": False},
343
+ "dataloader": {"batch_size": 1, "num_workers": 0, "pin_memory": False,
344
+ "mp_context": "spawn",
345
+ "persistent_workers": True,
346
+ "prefetch_factor": 1},
332
347
  },
333
348
  "test": {
334
349
  "path": test_path,
335
350
  "video_path": str(data_dir / "test"),
336
351
  "results": str(result_path),
337
- "dataloader": {"batch_size": 1, "num_workers": 0, "pin_memory": False},
352
+ "dataloader": {"batch_size": 1, "num_workers": 0, "pin_memory": False,
353
+ "mp_context": "spawn",
354
+ "persistent_workers": True,
355
+ "prefetch_factor": 1},
338
356
  },
339
357
  },
340
358
  "MODEL": {"backbone": {"type": "smoke_backbone"}, "multi_gpu": False},
@@ -2,6 +2,8 @@ import json
2
2
  from pathlib import Path
3
3
  from types import SimpleNamespace
4
4
 
5
+ import numpy as np
6
+
5
7
  import opensportslib.datasets.classification_dataset as classification_dataset
6
8
 
7
9
 
@@ -45,7 +47,7 @@ def _make_config(data_dir: Path, valid_video_root: Path) -> SimpleNamespace:
45
47
  valid=SimpleNamespace(video_path=str(valid_video_root)),
46
48
  test=SimpleNamespace(video_path=str(data_dir / "test_root")),
47
49
  ),
48
- MODEL=SimpleNamespace(type="custom"),
50
+ MODEL=SimpleNamespace(type="custom", pretrained_model="smoke_backbone"),
49
51
  )
50
52
 
51
53
 
@@ -57,8 +59,25 @@ def test_video_dataset_resolves_relative_paths_from_selected_split_root(
57
59
  data_dir = tmp_path / "dataset_root"
58
60
  valid_video_root = tmp_path / "separate_valid_root"
59
61
  config = _make_config(data_dir, valid_video_root)
62
+ captured = {}
60
63
 
61
64
  monkeypatch.setattr(classification_dataset, "build_transform", lambda config, mode: None)
65
+ monkeypatch.setattr(
66
+ classification_dataset,
67
+ "process_frames",
68
+ lambda *args, **kwargs: np.zeros((16, 4, 4, 3), dtype=np.uint8),
69
+ )
70
+ monkeypatch.setattr(
71
+ classification_dataset,
72
+ "get_transforms_model",
73
+ lambda model_name: (lambda tensor: tensor),
74
+ )
75
+
76
+ def fake_read_video(path):
77
+ captured["path"] = path
78
+ return []
79
+
80
+ monkeypatch.setattr(classification_dataset, "read_video", fake_read_video)
62
81
 
63
82
  dataset = classification_dataset.VideoDataset(
64
83
  config,
@@ -67,8 +86,11 @@ def test_video_dataset_resolves_relative_paths_from_selected_split_root(
67
86
  split="valid",
68
87
  )
69
88
 
70
- resolved_path = Path(dataset.samples[0]["video_paths"][0])
89
+ sample = dataset[0]
90
+
91
+ resolved_path = Path(captured["path"])
71
92
 
72
93
  assert dataset.split == "valid"
73
94
  assert resolved_path.is_absolute()
74
95
  assert resolved_path == valid_video_root / "clips" / "video_00000.mp4"
96
+ assert sample["id"] == "sample_00000"
@@ -0,0 +1,127 @@
1
+ from types import SimpleNamespace
2
+
3
+ import torch
4
+ from omegaconf import OmegaConf
5
+
6
+ from opensportslib.core.trainer import classification_trainer
7
+
8
+
9
+ class _FakeTrainer:
10
+ def __init__(self, **kwargs):
11
+ self.kwargs = kwargs
12
+
13
+ def train(self, epoch_start=0, save_every=1):
14
+ del epoch_start, save_every
15
+
16
+
17
+ class _FakeDataset:
18
+ label_map = {0: "PASS"}
19
+
20
+ def __len__(self):
21
+ return 1
22
+
23
+ def num_classes(self):
24
+ return 1
25
+
26
+ def get_class_weights(self, num_classes=None, sqrt=False):
27
+ del num_classes, sqrt
28
+ return torch.ones(1)
29
+
30
+ def get_sample_weights(self):
31
+ return torch.ones(1)
32
+
33
+
34
+ def _make_config(mp_context=None):
35
+ dataloader = {
36
+ "batch_size": 1,
37
+ "num_workers": 1,
38
+ "pin_memory": False,
39
+ "persistent_workers": True,
40
+ "prefetch_factor": 4,
41
+ }
42
+ if mp_context is not None:
43
+ dataloader["mp_context"] = mp_context
44
+
45
+ return OmegaConf.create(
46
+ {
47
+ "DATA": {
48
+ "data_modality": "video",
49
+ "train": {"dataloader": dict(dataloader)},
50
+ "valid": {"dataloader": dict(dataloader)},
51
+ },
52
+ "MODEL": {
53
+ "type": "custom",
54
+ "backbone": {"type": "smoke_backbone"},
55
+ },
56
+ "TRAIN": {
57
+ "use_weighted_loss": False,
58
+ "use_weighted_sampler": False,
59
+ "optimizer": {"type": "SGD", "lr": 0.1},
60
+ "scheduler": {"type": "StepLR", "step_size": 1, "gamma": 0.1},
61
+ "criterion": {"type": "CrossEntropyLoss"},
62
+ "epochs": 1,
63
+ "save_every": 1,
64
+ },
65
+ "SYSTEM": {
66
+ "seed": 0,
67
+ "device": "cpu",
68
+ "save_dir": ".",
69
+ },
70
+ }
71
+ )
72
+
73
+
74
+ def _run_train(monkeypatch, config):
75
+ dataloader_calls = []
76
+
77
+ monkeypatch.setattr(
78
+ classification_trainer,
79
+ "select_device",
80
+ lambda system: torch.device("cpu"),
81
+ )
82
+ monkeypatch.setattr(
83
+ classification_trainer,
84
+ "DataLoader",
85
+ lambda dataset, **kwargs: dataloader_calls.append(kwargs) or SimpleNamespace(),
86
+ )
87
+ monkeypatch.setattr(
88
+ classification_trainer,
89
+ "MVTrainerClassification",
90
+ _FakeTrainer,
91
+ )
92
+ monkeypatch.setattr(
93
+ "opensportslib.core.optimizer.builder.build_optimizer",
94
+ lambda params, cfg: object(),
95
+ )
96
+ monkeypatch.setattr(
97
+ "opensportslib.core.scheduler.builder.build_scheduler",
98
+ lambda optimizer, cfg: object(),
99
+ )
100
+ monkeypatch.setattr(
101
+ "opensportslib.core.loss.builder.build_criterion",
102
+ lambda cfg: object(),
103
+ )
104
+
105
+ trainer = classification_trainer.Trainer_Classification(config)
106
+ trainer.train(torch.nn.Linear(1, 1), _FakeDataset(), _FakeDataset())
107
+
108
+ return dataloader_calls
109
+
110
+
111
+ def test_video_train_loader_respects_explicit_spawn_context(monkeypatch):
112
+ dataloader_calls = _run_train(monkeypatch, _make_config("spawn"))
113
+
114
+ assert len(dataloader_calls) == 2
115
+ assert dataloader_calls[0]["num_workers"] == 1
116
+ assert dataloader_calls[0]["pin_memory"] is False
117
+ assert (dataloader_calls[0]["multiprocessing_context"].get_start_method() == "spawn")
118
+ assert (dataloader_calls[1]["multiprocessing_context"].get_start_method() == "spawn")
119
+
120
+
121
+
122
+ def test_video_train_loader_respects_explicit_context(monkeypatch):
123
+ dataloader_calls = _run_train(monkeypatch, _make_config("forkserver"))
124
+
125
+ assert len(dataloader_calls) == 2
126
+ assert (dataloader_calls[0]["multiprocessing_context"].get_start_method() == "forkserver")
127
+ assert (dataloader_calls[1]["multiprocessing_context"].get_start_method() == "forkserver")
@@ -208,7 +208,7 @@ def test_localization_evaluate_uses_provided_predictions(
208
208
  )
209
209
  monkeypatch.setattr(
210
210
  "opensportslib.core.utils.config.resolve_config_omega",
211
- lambda config: config,
211
+ lambda config, weights=None: config,
212
212
  )
213
213
  monkeypatch.setattr(
214
214
  "opensportslib.core.utils.load_annotations.check_config",
@@ -304,6 +304,7 @@ def test_localization_constructor_weights_are_default_for_train_and_infer(
304
304
  self.model = object()
305
305
  self.last_loaded_weights = weights
306
306
  self.best_checkpoint = weights
307
+ self._resume_state = {"source_weights": weights}
307
308
 
308
309
  def fake_build_trainer(cfg, model, default_args, resume_from=None):
309
310
  del cfg, model, default_args
@@ -329,7 +330,7 @@ def test_localization_constructor_weights_are_default_for_train_and_infer(
329
330
  )
330
331
  monkeypatch.setattr(
331
332
  "opensportslib.core.utils.config.resolve_config_omega",
332
- lambda config: config,
333
+ lambda config, weights=None: config,
333
334
  )
334
335
  monkeypatch.setattr(
335
336
  "opensportslib.core.utils.config.select_device",
@@ -367,9 +368,9 @@ def test_localization_constructor_weights_are_default_for_train_and_infer(
367
368
  train_api = LocalizationModel(config=localization_config_path, weights="default")
368
369
  train_api.config = make_config()
369
370
  train_api.train(use_wandb=False)
370
- assert trainer_resume_from[-1] == "default"
371
+ assert trainer_resume_from[-1]["source_weights"] == "default"
371
372
 
372
373
  train_api = LocalizationModel(config=localization_config_path, weights="default")
373
374
  train_api.config = make_config()
374
375
  train_api.train(weights="override", use_wandb=False)
375
- assert trainer_resume_from[-1] == "override"
376
+ assert trainer_resume_from[-1]["source_weights"] == "override"