opensportslib 0.1.3.dev3__tar.gz → 0.1.3.dev4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opensportslib-0.1.3.dev3/opensportslib.egg-info → opensportslib-0.1.3.dev4}/PKG-INFO +4 -3
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/README.md +3 -2
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/apis/base_task_model.py +18 -9
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/apis/classification.py +63 -43
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/apis/localization.py +61 -61
- opensportslib-0.1.3.dev4/opensportslib/configs/captioning/encoder_decoder/video_captioning.yaml +136 -0
- opensportslib-0.1.3.dev4/opensportslib/configs/captioning/llava/llava_style.yaml +148 -0
- opensportslib-0.1.3.dev4/opensportslib/configs/classification/default.yaml +125 -0
- opensportslib-0.1.3.dev4/opensportslib/configs/classification/frames_npy/sngar-frames.yaml +171 -0
- opensportslib-0.1.3.dev4/opensportslib/configs/classification/tracking/sngar-tracking.yaml +170 -0
- opensportslib-0.1.3.dev4/opensportslib/configs/classification/video/classification.yaml +178 -0
- opensportslib-0.1.3.dev4/opensportslib/configs/localization/default.yaml +134 -0
- opensportslib-0.1.3.dev4/opensportslib/configs/localization/video/localization-dali.yaml +180 -0
- opensportslib-0.1.3.dev4/opensportslib/configs/localization/video/localization-ocv.yaml +189 -0
- opensportslib-0.1.3.dev4/opensportslib/configs/localization/video_features/localization-calf-resnetpca512.yaml +201 -0
- opensportslib-0.1.3.dev4/opensportslib/configs/localization/video_features/localization-netvladpp-resnetpca512.yaml +192 -0
- opensportslib-0.1.3.dev4/opensportslib/configs/reasoning/multimodal/video_text_fusion.yaml +161 -0
- opensportslib-0.1.3.dev4/opensportslib/configs/retrieval/two_tower/video_text_retrieval.yaml +150 -0
- opensportslib-0.1.3.dev4/opensportslib/core/config/__init__.py +22 -0
- opensportslib-0.1.3.dev4/opensportslib/core/config/accessors.py +374 -0
- opensportslib-0.1.3.dev4/opensportslib/core/config/conflicts.py +130 -0
- opensportslib-0.1.3.dev4/opensportslib/core/config/loader.py +77 -0
- opensportslib-0.1.3.dev4/opensportslib/core/config/migrate.py +22 -0
- opensportslib-0.1.3.dev4/opensportslib/core/config/migrations/__init__.py +5 -0
- opensportslib-0.1.3.dev4/opensportslib/core/config/migrations/legacy_to_canonical.py +614 -0
- opensportslib-0.1.3.dev4/opensportslib/core/config/runtime_adapter.py +60 -0
- opensportslib-0.1.3.dev4/opensportslib/core/config/schema.py +61 -0
- opensportslib-0.1.3.dev4/opensportslib/core/config/schemas/__init__.py +6 -0
- opensportslib-0.1.3.dev4/opensportslib/core/config/schemas/schema_canonical.py +22 -0
- opensportslib-0.1.3.dev4/opensportslib/core/config/schemas/schema_legacy.py +17 -0
- opensportslib-0.1.3.dev4/opensportslib/core/config/validate.py +89 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/loss/builder.py +6 -2
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/optimizer/builder.py +4 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/sampler/weighted_sampler.py +8 -5
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/scheduler/builder.py +3 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/trainer/classification_trainer.py +138 -82
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/trainer/localization_trainer.py +81 -49
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/config.py +51 -42
- opensportslib-0.1.3.dev4/opensportslib/core/utils/config_normalize.py +29 -0
- opensportslib-0.1.3.dev4/opensportslib/core/utils/default_args.py +143 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/load_annotations.py +52 -30
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/video_processing.py +10 -4
- opensportslib-0.1.3.dev4/opensportslib/core/utils/wandb.py +217 -0
- opensportslib-0.1.3.dev4/opensportslib/datasets/builder.py +16 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/datasets/classification_dataset.py +140 -44
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/datasets/localization_dataset.py +243 -103
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/datasets/utils/tracking.py +18 -12
- {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/classification.yaml +4 -4
- {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/localization-e2e-ocv.yaml +1 -1
- {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/localization-json_calf_resnetpca512.yaml +1 -1
- {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/localization-json_netvlad++_resnetpca512.yaml +1 -1
- {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/localization.yaml +5 -6
- {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/sngar-frames.yaml +4 -4
- {opensportslib-0.1.3.dev3/opensportslib/config → opensportslib-0.1.3.dev4/opensportslib/legacy_config}/sngar-tracking.yaml +5 -5
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/metrics/localization_metric.py +15 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/backbones/builder.py +1 -1
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/contextaware.py +36 -17
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/learnablepooling.py +44 -22
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/tracking.py +24 -9
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/vars.py +7 -2
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/video.py +41 -11
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/video_mae.py +20 -5
- opensportslib-0.1.3.dev4/opensportslib/models/builder.py +145 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4/opensportslib.egg-info}/PKG-INFO +4 -3
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib.egg-info/SOURCES.txt +37 -7
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/pyproject.toml +2 -2
- opensportslib-0.1.3.dev4/tests/conftest.py +581 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_classification_dataset_paths.py +32 -12
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_classification_trainer_dataloader.py +47 -11
- opensportslib-0.1.3.dev4/tests/test_config_architecture.py +104 -0
- opensportslib-0.1.3.dev4/tests/test_config_split_override_sync.py +37 -0
- opensportslib-0.1.3.dev4/tests/test_pretrained_config_merge_policy.py +223 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_public_apis_smoke.py +4 -4
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_task_model_api_contract.py +35 -9
- opensportslib-0.1.3.dev3/opensportslib/core/utils/default_args.py +0 -110
- opensportslib-0.1.3.dev3/opensportslib/core/utils/wandb.py +0 -280
- opensportslib-0.1.3.dev3/opensportslib/datasets/builder.py +0 -42
- opensportslib-0.1.3.dev3/opensportslib/models/builder.py +0 -66
- opensportslib-0.1.3.dev3/tests/conftest.py +0 -377
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/LICENSE +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/LICENSE-COMMERCIAL +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/MANIFEST.in +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/examples/quickstart/basic_classification.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/examples/quickstart/basic_localization.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/__init__.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/apis/__init__.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/cli.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/__init__.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/loss/__init__.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/loss/calf.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/loss/ce.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/loss/combine.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/loss/nll.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/optimizer/__init__.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/scheduler/__init__.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/trainer/__init__.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/checkpoint.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/data.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/ddp.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/lightning.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/core/utils/seed.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/datasets/__init__.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/datasets/utils/__init__.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/metrics/classification_metric.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/__init__.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/base/e2e.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/heads/builder.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/neck/builder.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/common.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/impl/__init__.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/impl/asformer.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/impl/calf.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/impl/gsm.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/impl/gtad.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/impl/tsm.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/litebase.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/modules.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/shift.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/models/utils/utils.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/setup/setup.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/tools/__init__.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/tools/_common.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/tools/hf_transfer.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/tools/osl_json_to_parquet.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib/tools/parquet_to_osl_json.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib.egg-info/dependency_links.txt +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib.egg-info/entry_points.txt +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib.egg-info/requires.txt +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/opensportslib.egg-info/top_level.txt +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/setup.cfg +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_config_utils_smoke.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_conversion_tools.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_hf_transfer_tools.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_localization_dali_filenames.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_package_smoke.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tests/test_subset_train_infer_integration.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/convert/build_soccernet_gar.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/convert/build_soccernet_gar_action_spotting.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/convert/osl_json_to_parquet_webdataset.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/convert/parquet_webdataset_to_osl_json.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/download/download_hf_repo.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/download/download_osl_hf.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/download/upload_osl_hf.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/training/classification.py +0 -0
- {opensportslib-0.1.3.dev3 → opensportslib-0.1.3.dev4}/tools/training/localization.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: opensportslib
|
|
3
|
-
Version: 0.1.3.dev3
|
|
3
|
+
Version: 0.1.3.dev4
|
|
4
4
|
Summary: OpenSportsLib is the professional library, designed for advanced video understanding in sports. It provides state-of-the-art tools for action recognition, spotting, retrieval, and captioning, making it ideal for researchers, analysts, and developers working with sports video data.
|
|
5
5
|
Author: Jeet Vora
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -34,6 +34,7 @@ Requires-Dist: pytest-cov; extra == "test"
|
|
|
34
34
|
Dynamic: license-file
|
|
35
35
|
|
|
36
36
|
# OpenSportsLib
|
|
37
|
+
<img src="docs/assets/osl.jpg" height="400">
|
|
37
38
|
|
|
38
39
|
OpenSportsLib is a modular Python library for sports video understanding.
|
|
39
40
|
|
|
@@ -188,7 +189,7 @@ Minimal localization sample:
|
|
|
188
189
|
```
|
|
189
190
|
|
|
190
191
|
Relative paths in `inputs[].path` are resolved from the split media root in the
|
|
191
|
-
YAML config, for example `DATA.train.
|
|
192
|
+
YAML config, for example `DATA.common.splits.train.source_path`. See the full
|
|
192
193
|
[OSL JSON format guide](docs/data/osl-json-format.md) for field definitions,
|
|
193
194
|
multi-modal examples, prediction payloads, and conversion notes.
|
|
194
195
|
|
|
@@ -345,7 +346,7 @@ Use the README for the fast start, then go deeper through:
|
|
|
345
346
|
- Full documentation: https://opensportslab.github.io/opensportslib/
|
|
346
347
|
- OSL JSON format: [docs/data/osl-json-format.md](docs/data/osl-json-format.md)
|
|
347
348
|
- High-level API guide: [opensportslib/apis/README.md](opensportslib/apis/README.md)
|
|
348
|
-
- Configuration guide: https://opensportslab.github.io/opensportslib/
|
|
349
|
+
- Configuration guide: https://opensportslab.github.io/opensportslib/config/configuration-guide/
|
|
349
350
|
- Example configs: [examples/configs/](examples/configs/)
|
|
350
351
|
- Quickstart scripts: [examples/quickstart/](examples/quickstart/)
|
|
351
352
|
- Contribution guide: [CONTRIBUTING.md](CONTRIBUTING.md)
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
# OpenSportsLib
|
|
2
|
+
<img src="docs/assets/osl.jpg" height="400">
|
|
2
3
|
|
|
3
4
|
OpenSportsLib is a modular Python library for sports video understanding.
|
|
4
5
|
|
|
@@ -153,7 +154,7 @@ Minimal localization sample:
|
|
|
153
154
|
```
|
|
154
155
|
|
|
155
156
|
Relative paths in `inputs[].path` are resolved from the split media root in the
|
|
156
|
-
YAML config, for example `DATA.train.
|
|
157
|
+
YAML config, for example `DATA.common.splits.train.source_path`. See the full
|
|
157
158
|
[OSL JSON format guide](docs/data/osl-json-format.md) for field definitions,
|
|
158
159
|
multi-modal examples, prediction payloads, and conversion notes.
|
|
159
160
|
|
|
@@ -310,7 +311,7 @@ Use the README for the fast start, then go deeper through:
|
|
|
310
311
|
- Full documentation: https://opensportslab.github.io/opensportslib/
|
|
311
312
|
- OSL JSON format: [docs/data/osl-json-format.md](docs/data/osl-json-format.md)
|
|
312
313
|
- High-level API guide: [opensportslib/apis/README.md](opensportslib/apis/README.md)
|
|
313
|
-
- Configuration guide: https://opensportslab.github.io/opensportslib/
|
|
314
|
+
- Configuration guide: https://opensportslab.github.io/opensportslib/config/configuration-guide/
|
|
314
315
|
- Example configs: [examples/configs/](examples/configs/)
|
|
315
316
|
- Quickstart scripts: [examples/quickstart/](examples/quickstart/)
|
|
316
317
|
- Contribution guide: [CONTRIBUTING.md](CONTRIBUTING.md)
|
|
@@ -9,6 +9,7 @@ import uuid
|
|
|
9
9
|
from abc import ABC, abstractmethod
|
|
10
10
|
from typing import Any
|
|
11
11
|
|
|
12
|
+
from opensportslib.core.config.accessors import get_component_name_by_kind
|
|
12
13
|
from opensportslib.core.utils.config import expand, load_config_omega, fetch_and_merge_config_from_HF
|
|
13
14
|
|
|
14
15
|
|
|
@@ -23,6 +24,8 @@ class BaseTaskModel(ABC):
|
|
|
23
24
|
|
|
24
25
|
self.config_path = expand(config)
|
|
25
26
|
self.config = load_config_omega(self.config_path)
|
|
27
|
+
self.last_loaded_weights = None
|
|
28
|
+
self.best_checkpoint = None
|
|
26
29
|
|
|
27
30
|
if weights is not None:
|
|
28
31
|
self.config = fetch_and_merge_config_from_HF(self.config, weights, merge_policy="compatibility")
|
|
@@ -41,15 +44,23 @@ class BaseTaskModel(ABC):
|
|
|
41
44
|
|
|
42
45
|
system_cfg = getattr(self.config, "SYSTEM", None)
|
|
43
46
|
if system_cfg is not None:
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
47
|
+
system_paths = getattr(system_cfg, "paths", None)
|
|
48
|
+
base_save_dir = expand(
|
|
49
|
+
getattr(system_paths, "save_dir", None)
|
|
50
|
+
or getattr(system_cfg, "save_dir", None)
|
|
51
|
+
or "./checkpoints"
|
|
52
|
+
)
|
|
53
|
+
model_name = get_component_name_by_kind(self.config, "encoder") or "model"
|
|
48
54
|
run_save_dir = os.path.join(base_save_dir, model_name, self.run_id)
|
|
49
55
|
self.save_dir = run_save_dir
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
56
|
+
if system_paths is not None:
|
|
57
|
+
system_paths.save_dir = run_save_dir
|
|
58
|
+
if hasattr(system_paths, "work_dir"):
|
|
59
|
+
system_paths.work_dir = run_save_dir
|
|
60
|
+
else:
|
|
61
|
+
system_cfg.save_dir = run_save_dir
|
|
62
|
+
if hasattr(system_cfg, "work_dir"):
|
|
63
|
+
system_cfg.work_dir = run_save_dir
|
|
53
64
|
os.makedirs(run_save_dir, exist_ok=True)
|
|
54
65
|
else:
|
|
55
66
|
self.save_dir = expand("./checkpoints")
|
|
@@ -60,8 +71,6 @@ class BaseTaskModel(ABC):
|
|
|
60
71
|
self.model = None
|
|
61
72
|
self.processor = None
|
|
62
73
|
self.trainer = None
|
|
63
|
-
self.best_checkpoint = None
|
|
64
|
-
self.last_loaded_weights = None
|
|
65
74
|
|
|
66
75
|
if weights is not None:
|
|
67
76
|
self.load_weights(weights=weights)
|
|
@@ -4,8 +4,17 @@
|
|
|
4
4
|
|
|
5
5
|
import logging
|
|
6
6
|
import os
|
|
7
|
+
import json
|
|
7
8
|
|
|
8
9
|
from opensportslib.apis.base_task_model import BaseTaskModel
|
|
10
|
+
from opensportslib.core.config.accessors import (
|
|
11
|
+
get_component_provider_by_kind,
|
|
12
|
+
get_data_modality,
|
|
13
|
+
get_split_annotation_path,
|
|
14
|
+
get_system_gpu_count,
|
|
15
|
+
get_system_seed,
|
|
16
|
+
get_system_use_seed,
|
|
17
|
+
)
|
|
9
18
|
from opensportslib.core.utils.config import expand
|
|
10
19
|
|
|
11
20
|
class ClassificationModel(BaseTaskModel):
|
|
@@ -15,24 +24,13 @@ class ClassificationModel(BaseTaskModel):
|
|
|
15
24
|
if override is not None:
|
|
16
25
|
return expand(override)
|
|
17
26
|
|
|
18
|
-
|
|
19
|
-
split_cfg = getattr(data_cfg, split, None)
|
|
20
|
-
path = getattr(split_cfg, "path", None) if split_cfg is not None else None
|
|
21
|
-
if path:
|
|
22
|
-
return expand(path)
|
|
23
|
-
|
|
24
|
-
annotations_cfg = getattr(data_cfg, "annotations", None)
|
|
25
|
-
path = (
|
|
26
|
-
getattr(annotations_cfg, split, None)
|
|
27
|
-
if annotations_cfg is not None
|
|
28
|
-
else None
|
|
29
|
-
)
|
|
27
|
+
path = get_split_annotation_path(self.config, split)
|
|
30
28
|
if path:
|
|
31
29
|
return expand(path)
|
|
32
30
|
|
|
33
31
|
raise ValueError(
|
|
34
32
|
f"Could not resolve path for split '{split}'. "
|
|
35
|
-
f"Expected DATA.
|
|
33
|
+
f"Expected DATA.common.splits.{split}.annotation_path."
|
|
36
34
|
)
|
|
37
35
|
|
|
38
36
|
# -----------------------------------------------------------------
|
|
@@ -54,6 +52,7 @@ class ClassificationModel(BaseTaskModel):
|
|
|
54
52
|
):
|
|
55
53
|
"""Execute one training/inference job on a single process."""
|
|
56
54
|
import torch
|
|
55
|
+
import wandb
|
|
57
56
|
from opensportslib.core.trainer.classification_trainer import Trainer_Classification
|
|
58
57
|
from opensportslib.core.utils.ddp import ddp_cleanup, ddp_setup
|
|
59
58
|
from opensportslib.core.utils.wandb import init_wandb
|
|
@@ -77,8 +76,8 @@ class ClassificationModel(BaseTaskModel):
|
|
|
77
76
|
use_wandb=use_wandb,
|
|
78
77
|
)
|
|
79
78
|
|
|
80
|
-
if
|
|
81
|
-
set_reproducibility(config
|
|
79
|
+
if get_system_use_seed(config):
|
|
80
|
+
set_reproducibility(get_system_seed(config))
|
|
82
81
|
|
|
83
82
|
is_ddp = world_size > 1
|
|
84
83
|
if is_ddp:
|
|
@@ -100,32 +99,50 @@ class ClassificationModel(BaseTaskModel):
|
|
|
100
99
|
|
|
101
100
|
trainer.model = model
|
|
102
101
|
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
world_size=world_size,
|
|
112
|
-
)
|
|
113
|
-
if rank == 0 and return_queue is not None:
|
|
114
|
-
best_ckpt = best_ckpt or getattr(trainer.trainer, "best_checkpoint_path", None)
|
|
115
|
-
return_queue.put(best_ckpt)
|
|
116
|
-
|
|
117
|
-
elif mode == "infer":
|
|
118
|
-
test_data = build_dataset(config, test_set, processor, split="test")
|
|
119
|
-
predictions = trainer.infer(
|
|
120
|
-
test_data,
|
|
121
|
-
rank=rank,
|
|
122
|
-
world_size=world_size,
|
|
123
|
-
)
|
|
124
|
-
if rank == 0 and return_queue is not None:
|
|
125
|
-
return_queue.put(predictions)
|
|
102
|
+
modality = get_data_modality(config)
|
|
103
|
+
use_tracking_collate = modality in {"tracking", "tracking_parquet"}
|
|
104
|
+
logging.info(
|
|
105
|
+
"Worker setup | mode=%s | modality=%s | tracking_collate=%s",
|
|
106
|
+
mode,
|
|
107
|
+
modality,
|
|
108
|
+
use_tracking_collate,
|
|
109
|
+
)
|
|
126
110
|
|
|
127
|
-
|
|
128
|
-
|
|
111
|
+
try:
|
|
112
|
+
if mode == "train":
|
|
113
|
+
train_data = build_dataset(config, train_set, processor, split="train")
|
|
114
|
+
valid_data = build_dataset(config, valid_set, processor, split="valid")
|
|
115
|
+
best_ckpt = trainer.train(
|
|
116
|
+
model,
|
|
117
|
+
train_data,
|
|
118
|
+
valid_data,
|
|
119
|
+
rank=rank,
|
|
120
|
+
world_size=world_size,
|
|
121
|
+
)
|
|
122
|
+
if rank == 0 and return_queue is not None:
|
|
123
|
+
best_ckpt = best_ckpt or getattr(trainer.trainer, "best_checkpoint_path", None)
|
|
124
|
+
return_queue.put(best_ckpt)
|
|
125
|
+
|
|
126
|
+
elif mode == "infer":
|
|
127
|
+
test_data = build_dataset(config, test_set, processor, split="test")
|
|
128
|
+
_ = trainer.infer(
|
|
129
|
+
test_data,
|
|
130
|
+
rank=rank,
|
|
131
|
+
world_size=world_size,
|
|
132
|
+
)
|
|
133
|
+
if rank == 0 and return_queue is not None:
|
|
134
|
+
if world_size > 1:
|
|
135
|
+
return_queue.put(getattr(trainer, "predictions_path", None))
|
|
136
|
+
else:
|
|
137
|
+
return_queue.put(getattr(trainer, "predictions_payload", None))
|
|
138
|
+
finally:
|
|
139
|
+
if rank == 0 and use_wandb and getattr(wandb, "run", None) is not None:
|
|
140
|
+
try:
|
|
141
|
+
wandb.finish(quiet=True)
|
|
142
|
+
except Exception:
|
|
143
|
+
pass
|
|
144
|
+
if is_ddp:
|
|
145
|
+
ddp_cleanup()
|
|
129
146
|
|
|
130
147
|
def load_weights(
|
|
131
148
|
self,
|
|
@@ -142,7 +159,7 @@ class ClassificationModel(BaseTaskModel):
|
|
|
142
159
|
loaded = self.trainer.load(weights)
|
|
143
160
|
self.model = loaded[0]
|
|
144
161
|
|
|
145
|
-
if
|
|
162
|
+
if get_component_provider_by_kind(self.config, "encoder") == "huggingface":
|
|
146
163
|
self.processor = loaded[1]
|
|
147
164
|
|
|
148
165
|
self.last_loaded_weights = weights
|
|
@@ -180,11 +197,11 @@ class ClassificationModel(BaseTaskModel):
|
|
|
180
197
|
|
|
181
198
|
effective_weights = weights if weights is not None else self.last_loaded_weights
|
|
182
199
|
|
|
183
|
-
world_size = torch.cuda.device_count() or self.config
|
|
200
|
+
world_size = torch.cuda.device_count() or get_system_gpu_count(self.config)
|
|
184
201
|
use_ddp = use_ddp and world_size > 1
|
|
185
202
|
|
|
186
203
|
ctx = mp.get_context("spawn")
|
|
187
|
-
queue = ctx.
|
|
204
|
+
queue = ctx.SimpleQueue()
|
|
188
205
|
|
|
189
206
|
if use_ddp:
|
|
190
207
|
logging.info(f"Launching DDP on {world_size} GPUs")
|
|
@@ -283,6 +300,9 @@ class ClassificationModel(BaseTaskModel):
|
|
|
283
300
|
)
|
|
284
301
|
|
|
285
302
|
predictions = queue.get()
|
|
303
|
+
if use_ddp and isinstance(predictions, str):
|
|
304
|
+
with open(predictions, encoding="utf-8") as f:
|
|
305
|
+
predictions = json.load(f)
|
|
286
306
|
return predictions
|
|
287
307
|
|
|
288
308
|
def evaluate(
|
|
@@ -3,6 +3,18 @@ import os
|
|
|
3
3
|
import time
|
|
4
4
|
|
|
5
5
|
from opensportslib.apis.base_task_model import BaseTaskModel
|
|
6
|
+
from opensportslib.core.config.accessors import (
|
|
7
|
+
get_data_classes,
|
|
8
|
+
get_loader_backend,
|
|
9
|
+
get_system_gpu_count,
|
|
10
|
+
get_system_seed,
|
|
11
|
+
set_system_path,
|
|
12
|
+
get_train_trainer_type,
|
|
13
|
+
get_train_execution,
|
|
14
|
+
get_split_annotation_path,
|
|
15
|
+
get_split_cfg,
|
|
16
|
+
set_split_annotation_path,
|
|
17
|
+
)
|
|
6
18
|
from opensportslib.core.utils.config import expand
|
|
7
19
|
|
|
8
20
|
|
|
@@ -21,44 +33,33 @@ class LocalizationModel(BaseTaskModel):
|
|
|
21
33
|
if override is not None:
|
|
22
34
|
return expand(override)
|
|
23
35
|
|
|
24
|
-
|
|
25
|
-
split_cfg = getattr(data_cfg, split, None)
|
|
26
|
-
path = getattr(split_cfg, "path", None) if split_cfg is not None else None
|
|
27
|
-
if path:
|
|
28
|
-
return expand(path)
|
|
29
|
-
|
|
30
|
-
annotations_cfg = getattr(data_cfg, "annotations", None)
|
|
31
|
-
path = (
|
|
32
|
-
getattr(annotations_cfg, split, None)
|
|
33
|
-
if annotations_cfg is not None
|
|
34
|
-
else None
|
|
35
|
-
)
|
|
36
|
+
path = get_split_annotation_path(self.config, split)
|
|
36
37
|
if path:
|
|
37
38
|
return expand(path)
|
|
38
39
|
|
|
39
40
|
raise ValueError(
|
|
40
41
|
f"Could not resolve path for split '{split}'. "
|
|
41
|
-
f"Expected DATA.
|
|
42
|
+
f"Expected DATA.common.splits.{split}.annotation_path."
|
|
42
43
|
)
|
|
43
44
|
|
|
44
45
|
def _set_split_path(self, split: str, value: str) -> str:
|
|
45
46
|
resolved = expand(value)
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
47
|
+
set_split_annotation_path(self.config, split, resolved)
|
|
48
|
+
return resolved
|
|
49
|
+
|
|
50
|
+
def _gate_multi_gpu_by_device(self, device) -> None:
|
|
51
|
+
"""Disable TRAIN.execution.multi_gpu when effective device is CPU."""
|
|
52
|
+
execution = getattr(getattr(self.config, "TRAIN", None), "execution", None)
|
|
53
|
+
if execution is None:
|
|
54
|
+
return
|
|
55
|
+
|
|
56
|
+
multi_gpu = bool(getattr(execution, "multi_gpu", False))
|
|
57
|
+
if device.type == "cpu" and multi_gpu:
|
|
58
|
+
execution.multi_gpu = False
|
|
59
|
+
logging.warning(
|
|
60
|
+
"Detected SYSTEM.device=%s; forcing TRAIN.execution.multi_gpu=false for localization runtime.",
|
|
61
|
+
device,
|
|
62
|
+
)
|
|
62
63
|
|
|
63
64
|
def load_weights(
|
|
64
65
|
self,
|
|
@@ -78,13 +79,8 @@ class LocalizationModel(BaseTaskModel):
|
|
|
78
79
|
if weights is None:
|
|
79
80
|
raise ValueError("`weights` must be provided to load_weights().")
|
|
80
81
|
|
|
81
|
-
model_cfg = getattr(self.config, "MODEL", None)
|
|
82
|
-
if not self.train_flag:
|
|
83
|
-
original_multi_gpu = getattr(model_cfg, "multi_gpu", None)
|
|
84
|
-
if model_cfg is not None and original_multi_gpu is not None:
|
|
85
|
-
model_cfg.multi_gpu = False
|
|
86
|
-
|
|
87
82
|
device = select_device(self.config.SYSTEM)
|
|
83
|
+
self._gate_multi_gpu_by_device(device)
|
|
88
84
|
if self.model is None:
|
|
89
85
|
self.model = build_model(self.config, device=device)
|
|
90
86
|
|
|
@@ -93,7 +89,11 @@ class LocalizationModel(BaseTaskModel):
|
|
|
93
89
|
inner_model = getattr(self.model, "model", self.model)
|
|
94
90
|
|
|
95
91
|
if is_local_path(weights):
|
|
96
|
-
|
|
92
|
+
set_system_path(
|
|
93
|
+
self.config,
|
|
94
|
+
"work_dir",
|
|
95
|
+
os.path.dirname(os.path.abspath(weights)),
|
|
96
|
+
)
|
|
97
97
|
|
|
98
98
|
if default_args is not None:
|
|
99
99
|
logging.info("Building optimizer + scaler for checkpoint restore...")
|
|
@@ -135,7 +135,7 @@ class LocalizationModel(BaseTaskModel):
|
|
|
135
135
|
|
|
136
136
|
best_criterion_valid = checkpoint.get(
|
|
137
137
|
"best_criterion_valid",
|
|
138
|
-
0 if self.config.
|
|
138
|
+
0 if get_train_execution(self.config).get("criterion_valid") == "map" else float("inf")
|
|
139
139
|
)
|
|
140
140
|
self._resume_state = {
|
|
141
141
|
"optimizer": optimizer,
|
|
@@ -146,10 +146,6 @@ class LocalizationModel(BaseTaskModel):
|
|
|
146
146
|
"best_criterion_valid": best_criterion_valid,
|
|
147
147
|
}
|
|
148
148
|
|
|
149
|
-
if not self.train_flag:
|
|
150
|
-
if model_cfg is not None and original_multi_gpu is not None:
|
|
151
|
-
model_cfg.multi_gpu = original_multi_gpu
|
|
152
|
-
|
|
153
149
|
def train(
|
|
154
150
|
self,
|
|
155
151
|
train_set=None,
|
|
@@ -178,6 +174,9 @@ class LocalizationModel(BaseTaskModel):
|
|
|
178
174
|
valid_set = self._resolve_split_path("valid", valid_set)
|
|
179
175
|
self._set_split_path("train", train_set)
|
|
180
176
|
self._set_split_path("valid", valid_set)
|
|
177
|
+
# E2E validation mAP uses the `valid_data_frames` split; keep it in sync
|
|
178
|
+
# with explicit valid annotation overrides.
|
|
179
|
+
self._set_split_path("valid_data_frames", valid_set)
|
|
181
180
|
|
|
182
181
|
self.config = resolve_config_omega(self.config, weights=weights)
|
|
183
182
|
check_config(self.config, split="train")
|
|
@@ -203,34 +202,34 @@ class LocalizationModel(BaseTaskModel):
|
|
|
203
202
|
torch.backends.cudnn.benchmark = False
|
|
204
203
|
torch.use_deterministic_algorithms(True, warn_only=True)
|
|
205
204
|
|
|
206
|
-
set_seed(self.config
|
|
205
|
+
set_seed(get_system_seed(self.config))
|
|
207
206
|
|
|
208
207
|
start = time.time()
|
|
209
208
|
|
|
210
209
|
data_obj_train = build_dataset(self.config, split="train")
|
|
211
210
|
dataset_train = data_obj_train.building_dataset(
|
|
212
211
|
cfg=data_obj_train.cfg,
|
|
213
|
-
gpu=self.config
|
|
212
|
+
gpu=get_system_gpu_count(self.config),
|
|
214
213
|
default_args=data_obj_train.default_args,
|
|
215
214
|
)
|
|
216
215
|
train_loader = data_obj_train.building_dataloader(
|
|
217
216
|
dataset_train,
|
|
218
217
|
cfg=data_obj_train.cfg.dataloader,
|
|
219
|
-
gpu=self.config
|
|
220
|
-
dali=self.config
|
|
218
|
+
gpu=get_system_gpu_count(self.config),
|
|
219
|
+
dali=(get_loader_backend(self.config) == "dali"),
|
|
221
220
|
)
|
|
222
221
|
|
|
223
222
|
data_obj_valid = build_dataset(self.config, split="valid")
|
|
224
223
|
dataset_valid = data_obj_valid.building_dataset(
|
|
225
224
|
cfg=data_obj_valid.cfg,
|
|
226
|
-
gpu=self.config
|
|
225
|
+
gpu=get_system_gpu_count(self.config),
|
|
227
226
|
default_args=data_obj_valid.default_args,
|
|
228
227
|
)
|
|
229
228
|
valid_loader = data_obj_valid.building_dataloader(
|
|
230
229
|
dataset_valid,
|
|
231
230
|
cfg=data_obj_valid.cfg.dataloader,
|
|
232
|
-
gpu=self.config
|
|
233
|
-
dali=self.config
|
|
231
|
+
gpu=get_system_gpu_count(self.config),
|
|
232
|
+
dali=(get_loader_backend(self.config) == "dali"),
|
|
234
233
|
)
|
|
235
234
|
|
|
236
235
|
default_args = get_default_args_trainer(self.config, len(train_loader))
|
|
@@ -241,6 +240,7 @@ class LocalizationModel(BaseTaskModel):
|
|
|
241
240
|
self.load_weights(weights=effective_weights, default_args=default_args)
|
|
242
241
|
elif self.model is None:
|
|
243
242
|
device = select_device(self.config.SYSTEM)
|
|
243
|
+
self._gate_multi_gpu_by_device(device)
|
|
244
244
|
self.model = build_model(self.config, device=device)
|
|
245
245
|
|
|
246
246
|
self.trainer = build_trainer(
|
|
@@ -257,8 +257,8 @@ class LocalizationModel(BaseTaskModel):
|
|
|
257
257
|
self.model,
|
|
258
258
|
train_loader,
|
|
259
259
|
valid_loader,
|
|
260
|
-
self.config
|
|
261
|
-
self.config
|
|
260
|
+
get_data_classes(self.config),
|
|
261
|
+
get_train_trainer_type(self.config),
|
|
262
262
|
)
|
|
263
263
|
)
|
|
264
264
|
|
|
@@ -291,10 +291,9 @@ class LocalizationModel(BaseTaskModel):
|
|
|
291
291
|
test_set = self._resolve_split_path("test", test_set)
|
|
292
292
|
self._set_split_path("test", test_set)
|
|
293
293
|
|
|
294
|
-
self.config.MODEL.multi_gpu = False
|
|
295
294
|
self.config = resolve_config_omega(self.config, weights=weights)
|
|
296
295
|
check_config(self.config, split="test")
|
|
297
|
-
self.config.infer_split = whether_infer_split(self.config
|
|
296
|
+
self.config.infer_split = whether_infer_split(get_split_cfg(self.config, "test"))
|
|
298
297
|
|
|
299
298
|
init_wandb(
|
|
300
299
|
self.config_path,
|
|
@@ -315,22 +314,23 @@ class LocalizationModel(BaseTaskModel):
|
|
|
315
314
|
self.load_weights(weights=effective_weights)
|
|
316
315
|
elif self.model is None:
|
|
317
316
|
device = select_device(self.config.SYSTEM)
|
|
317
|
+
self._gate_multi_gpu_by_device(device)
|
|
318
318
|
self.model = build_model(self.config, device=device)
|
|
319
319
|
|
|
320
320
|
data_obj_test = build_dataset(self.config, split="test")
|
|
321
321
|
dataset_test = data_obj_test.building_dataset(
|
|
322
322
|
cfg=data_obj_test.cfg,
|
|
323
|
-
gpu=self.config
|
|
323
|
+
gpu=get_system_gpu_count(self.config),
|
|
324
324
|
default_args=data_obj_test.default_args,
|
|
325
325
|
)
|
|
326
326
|
test_loader = data_obj_test.building_dataloader(
|
|
327
327
|
dataset_test,
|
|
328
328
|
cfg=data_obj_test.cfg.dataloader,
|
|
329
|
-
gpu=self.config
|
|
330
|
-
dali=self.config
|
|
329
|
+
gpu=get_system_gpu_count(self.config),
|
|
330
|
+
dali=(get_loader_backend(self.config) == "dali"),
|
|
331
331
|
)
|
|
332
332
|
|
|
333
|
-
inferer = build_inferer(cfg=self.config
|
|
333
|
+
inferer = build_inferer(cfg=self.config, model=self.model)
|
|
334
334
|
predictions = inferer.infer(
|
|
335
335
|
cfg=self.config,
|
|
336
336
|
data=dataset_test,
|
|
@@ -361,10 +361,9 @@ class LocalizationModel(BaseTaskModel):
|
|
|
361
361
|
|
|
362
362
|
test_set = self._resolve_split_path("test", test_set)
|
|
363
363
|
self._set_split_path("test", test_set)
|
|
364
|
-
self.config.MODEL.multi_gpu = False
|
|
365
364
|
self.config = resolve_config_omega(self.config, weights=weights)
|
|
366
365
|
check_config(self.config, split="test")
|
|
367
|
-
self.config.infer_split = whether_infer_split(self.config
|
|
366
|
+
self.config.infer_split = whether_infer_split(get_split_cfg(self.config, "test"))
|
|
368
367
|
|
|
369
368
|
init_wandb(
|
|
370
369
|
self.config_path,
|
|
@@ -382,16 +381,17 @@ class LocalizationModel(BaseTaskModel):
|
|
|
382
381
|
|
|
383
382
|
metrics = None
|
|
384
383
|
|
|
385
|
-
|
|
384
|
+
test_path = get_split_annotation_path(self.config, "test")
|
|
385
|
+
if has_localization_events(test_path):
|
|
386
386
|
logging.info("Ground truth labels detected -> running evaluation")
|
|
387
387
|
evaluator = build_evaluator(cfg=self.config)
|
|
388
388
|
eval_input = (
|
|
389
|
-
self.config
|
|
389
|
+
getattr(get_split_cfg(self.config, "test"), "results")
|
|
390
390
|
if isinstance(predictions, dict)
|
|
391
391
|
else predictions
|
|
392
392
|
)
|
|
393
393
|
metrics = evaluator.evaluate(
|
|
394
|
-
cfg_testset=self.config
|
|
394
|
+
cfg_testset=get_split_cfg(self.config, "test"),
|
|
395
395
|
json_gz_file=eval_input,
|
|
396
396
|
)
|
|
397
397
|
else:
|