dslighting 1.7.1-py3-none-any.whl → 1.7.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dslighting/__init__.py +1 -1
- dslighting/core/agent.py +78 -62
- {dslighting-1.7.1.dist-info → dslighting-1.7.6.dist-info}/METADATA +1 -1
- {dslighting-1.7.1.dist-info → dslighting-1.7.6.dist-info}/RECORD +352 -7
- {dslighting-1.7.1.dist-info → dslighting-1.7.6.dist-info}/top_level.txt +1 -0
- mlebench/README.md +39 -0
- mlebench/__init__.py +0 -0
- mlebench/cli.py +221 -0
- mlebench/competitions/3d-object-detection-for-autonomous-vehicles/grade.py +161 -0
- mlebench/competitions/3d-object-detection-for-autonomous-vehicles/mAP_evaluation.py +425 -0
- mlebench/competitions/3d-object-detection-for-autonomous-vehicles/prepare.py +483 -0
- mlebench/competitions/3d-object-detection-for-autonomous-vehicles/prepare_val.py +719 -0
- mlebench/competitions/AI4Code/grade.py +70 -0
- mlebench/competitions/AI4Code/prepare.py +84 -0
- mlebench/competitions/AI4Code/prepare_val.py +159 -0
- mlebench/competitions/__init__.py +0 -0
- mlebench/competitions/aerial-cactus-identification/grade.py +11 -0
- mlebench/competitions/aerial-cactus-identification/prepare.py +71 -0
- mlebench/competitions/aerial-cactus-identification/prepare_val.py +133 -0
- mlebench/competitions/alaska2-image-steganalysis/grade.py +136 -0
- mlebench/competitions/alaska2-image-steganalysis/prepare.py +88 -0
- mlebench/competitions/alaska2-image-steganalysis/prepare_val.py +148 -0
- mlebench/competitions/aptos2019-blindness-detection/grade.py +35 -0
- mlebench/competitions/aptos2019-blindness-detection/prepare.py +75 -0
- mlebench/competitions/aptos2019-blindness-detection/prepare_val.py +123 -0
- mlebench/competitions/bike-sharing-demand/__init__.py +0 -0
- mlebench/competitions/bike-sharing-demand/grade.py +55 -0
- mlebench/competitions/bike-sharing-demand/prepare.py +37 -0
- mlebench/competitions/billion-word-imputation/grade.py +37 -0
- mlebench/competitions/billion-word-imputation/prepare.py +107 -0
- mlebench/competitions/billion-word-imputation/prepare_val.py +179 -0
- mlebench/competitions/bms-molecular-translation/grade.py +40 -0
- mlebench/competitions/bms-molecular-translation/prepare.py +68 -0
- mlebench/competitions/bms-molecular-translation/prepare_val.py +131 -0
- mlebench/competitions/cassava-leaf-disease-classification/grade.py +12 -0
- mlebench/competitions/cassava-leaf-disease-classification/prepare.py +113 -0
- mlebench/competitions/cassava-leaf-disease-classification/prepare_val.py +186 -0
- mlebench/competitions/cdiscount-image-classification-challenge/grade.py +11 -0
- mlebench/competitions/cdiscount-image-classification-challenge/prepare.py +144 -0
- mlebench/competitions/cdiscount-image-classification-challenge/prepare_val.py +205 -0
- mlebench/competitions/chaii-hindi-and-tamil-question-answering/grade.py +67 -0
- mlebench/competitions/chaii-hindi-and-tamil-question-answering/prepare.py +31 -0
- mlebench/competitions/chaii-hindi-and-tamil-question-answering/prepare_val.py +94 -0
- mlebench/competitions/champs-scalar-coupling/grade.py +60 -0
- mlebench/competitions/champs-scalar-coupling/prepare.py +116 -0
- mlebench/competitions/champs-scalar-coupling/prepare_val.py +155 -0
- mlebench/competitions/conways-reverse-game-of-life-2020/__init__.py +0 -0
- mlebench/competitions/conways-reverse-game-of-life-2020/grade.py +40 -0
- mlebench/competitions/conways-reverse-game-of-life-2020/prepare.py +41 -0
- mlebench/competitions/demand-forecasting-kernels-only/__init__.py +0 -0
- mlebench/competitions/demand-forecasting-kernels-only/grade.py +66 -0
- mlebench/competitions/demand-forecasting-kernels-only/prepare.py +27 -0
- mlebench/competitions/demand_forecasting_kernels_only/__init__.py +0 -0
- mlebench/competitions/demand_forecasting_kernels_only/grade.py +66 -0
- mlebench/competitions/demand_forecasting_kernels_only/prepare.py +27 -0
- mlebench/competitions/denoising-dirty-documents/grade.py +44 -0
- mlebench/competitions/denoising-dirty-documents/prepare.py +134 -0
- mlebench/competitions/denoising-dirty-documents/prepare_val.py +178 -0
- mlebench/competitions/detecting-insults-in-social-commentary/grade.py +11 -0
- mlebench/competitions/detecting-insults-in-social-commentary/prepare.py +72 -0
- mlebench/competitions/detecting-insults-in-social-commentary/prepare_val.py +128 -0
- mlebench/competitions/dog-breed-identification/dogs.py +124 -0
- mlebench/competitions/dog-breed-identification/grade.py +42 -0
- mlebench/competitions/dog-breed-identification/prepare.py +55 -0
- mlebench/competitions/dog-breed-identification/prepare_val.py +104 -0
- mlebench/competitions/dogs-vs-cats-redux-kernels-edition/grade.py +43 -0
- mlebench/competitions/dogs-vs-cats-redux-kernels-edition/prepare.py +70 -0
- mlebench/competitions/dogs-vs-cats-redux-kernels-edition/prepare_val.py +143 -0
- mlebench/competitions/ethanol-concentration/grade.py +23 -0
- mlebench/competitions/ethanol-concentration/prepare.py +90 -0
- mlebench/competitions/facebook-recruiting-iii-keyword-extraction/grade.py +60 -0
- mlebench/competitions/facebook-recruiting-iii-keyword-extraction/prepare.py +41 -0
- mlebench/competitions/facebook-recruiting-iii-keyword-extraction/prepare_val.py +92 -0
- mlebench/competitions/feedback-prize-english-language-learning/__init__.py +0 -0
- mlebench/competitions/feedback-prize-english-language-learning/grade.py +60 -0
- mlebench/competitions/feedback-prize-english-language-learning/prepare.py +39 -0
- mlebench/competitions/freesound-audio-tagging-2019/grade.py +64 -0
- mlebench/competitions/freesound-audio-tagging-2019/prepare.py +94 -0
- mlebench/competitions/freesound-audio-tagging-2019/prepare_val.py +175 -0
- mlebench/competitions/freesound-audio-tagging-2019/vocabulary.py +83 -0
- mlebench/competitions/google-quest-challenge/classes.py +32 -0
- mlebench/competitions/google-quest-challenge/grade.py +45 -0
- mlebench/competitions/google-quest-challenge/prepare.py +58 -0
- mlebench/competitions/google-quest-challenge/prepare_val.py +120 -0
- mlebench/competitions/google-research-identify-contrails-reduce-global-warming/grade.py +77 -0
- mlebench/competitions/google-research-identify-contrails-reduce-global-warming/prepare.py +155 -0
- mlebench/competitions/google-research-identify-contrails-reduce-global-warming/prepare_val.py +211 -0
- mlebench/competitions/h-and-m-personalized-fashion-recommendations/grade.py +42 -0
- mlebench/competitions/h-and-m-personalized-fashion-recommendations/prepare.py +102 -0
- mlebench/competitions/h-and-m-personalized-fashion-recommendations/prepare_val.py +132 -0
- mlebench/competitions/handwriting/grade.py +23 -0
- mlebench/competitions/handwriting/prepare.py +179 -0
- mlebench/competitions/herbarium-2020-fgvc7/grade.py +34 -0
- mlebench/competitions/herbarium-2020-fgvc7/prepare.py +251 -0
- mlebench/competitions/herbarium-2020-fgvc7/prepare_val.py +242 -0
- mlebench/competitions/herbarium-2021-fgvc8/grade.py +34 -0
- mlebench/competitions/herbarium-2021-fgvc8/prepare.py +251 -0
- mlebench/competitions/herbarium-2021-fgvc8/prepare_val.py +222 -0
- mlebench/competitions/herbarium-2022-fgvc9/grade.py +31 -0
- mlebench/competitions/herbarium-2022-fgvc9/prepare.py +233 -0
- mlebench/competitions/herbarium-2022-fgvc9/prepare_val.py +213 -0
- mlebench/competitions/histopathologic-cancer-detection/grade.py +12 -0
- mlebench/competitions/histopathologic-cancer-detection/prepare.py +59 -0
- mlebench/competitions/histopathologic-cancer-detection/prepare_val.py +131 -0
- mlebench/competitions/hms-harmful-brain-activity-classification/constants.py +9 -0
- mlebench/competitions/hms-harmful-brain-activity-classification/grade.py +43 -0
- mlebench/competitions/hms-harmful-brain-activity-classification/kaggle_metric_utilities.py +96 -0
- mlebench/competitions/hms-harmful-brain-activity-classification/kullback_leibler_divergence.py +118 -0
- mlebench/competitions/hms-harmful-brain-activity-classification/prepare.py +121 -0
- mlebench/competitions/hms-harmful-brain-activity-classification/prepare_val.py +190 -0
- mlebench/competitions/hotel-id-2021-fgvc8/grade.py +41 -0
- mlebench/competitions/hotel-id-2021-fgvc8/prepare.py +63 -0
- mlebench/competitions/hotel-id-2021-fgvc8/prepare_val.py +132 -0
- mlebench/competitions/hubmap-kidney-segmentation/grade.py +62 -0
- mlebench/competitions/hubmap-kidney-segmentation/prepare.py +108 -0
- mlebench/competitions/hubmap-kidney-segmentation/prepare_val.py +153 -0
- mlebench/competitions/icecube-neutrinos-in-deep-ice/grade.py +111 -0
- mlebench/competitions/icecube-neutrinos-in-deep-ice/prepare.py +127 -0
- mlebench/competitions/icecube-neutrinos-in-deep-ice/prepare_val.py +183 -0
- mlebench/competitions/ili/grade.py +60 -0
- mlebench/competitions/ili/prepare.py +99 -0
- mlebench/competitions/imet-2020-fgvc7/grade.py +54 -0
- mlebench/competitions/imet-2020-fgvc7/prepare.py +77 -0
- mlebench/competitions/imet-2020-fgvc7/prepare_val.py +157 -0
- mlebench/competitions/inaturalist-2019-fgvc6/grade.py +35 -0
- mlebench/competitions/inaturalist-2019-fgvc6/prepare.py +259 -0
- mlebench/competitions/inaturalist-2019-fgvc6/prepare_val.py +304 -0
- mlebench/competitions/instant-gratification/__init__.py +0 -0
- mlebench/competitions/instant-gratification/grade.py +55 -0
- mlebench/competitions/instant-gratification/prepare.py +25 -0
- mlebench/competitions/instant_gratification/__init__.py +0 -0
- mlebench/competitions/instant_gratification/grade.py +55 -0
- mlebench/competitions/instant_gratification/prepare.py +25 -0
- mlebench/competitions/invasive-species-monitoring/grade.py +11 -0
- mlebench/competitions/invasive-species-monitoring/prepare.py +97 -0
- mlebench/competitions/invasive-species-monitoring/prepare_val.py +164 -0
- mlebench/competitions/iwildcam-2019-fgvc6/grade.py +44 -0
- mlebench/competitions/iwildcam-2019-fgvc6/prepare.py +118 -0
- mlebench/competitions/iwildcam-2019-fgvc6/prepare_val.py +194 -0
- mlebench/competitions/iwildcam-2020-fgvc7/grade.py +11 -0
- mlebench/competitions/iwildcam-2020-fgvc7/prepare.py +164 -0
- mlebench/competitions/iwildcam-2020-fgvc7/prepare_val.py +245 -0
- mlebench/competitions/jigsaw-toxic-comment-classification-challenge/classes.py +1 -0
- mlebench/competitions/jigsaw-toxic-comment-classification-challenge/grade.py +54 -0
- mlebench/competitions/jigsaw-toxic-comment-classification-challenge/prepare.py +42 -0
- mlebench/competitions/jigsaw-toxic-comment-classification-challenge/prepare_val.py +88 -0
- mlebench/competitions/jigsaw-unintended-bias-in-toxicity-classification/grade.py +153 -0
- mlebench/competitions/jigsaw-unintended-bias-in-toxicity-classification/prepare.py +36 -0
- mlebench/competitions/jigsaw-unintended-bias-in-toxicity-classification/prepare_val.py +117 -0
- mlebench/competitions/kuzushiji-recognition/grade.py +58 -0
- mlebench/competitions/kuzushiji-recognition/kuzushiji_metric.py +118 -0
- mlebench/competitions/kuzushiji-recognition/prepare.py +92 -0
- mlebench/competitions/kuzushiji-recognition/prepare_val.py +149 -0
- mlebench/competitions/leaf-classification/classes.py +101 -0
- mlebench/competitions/leaf-classification/grade.py +44 -0
- mlebench/competitions/leaf-classification/prepare.py +60 -0
- mlebench/competitions/leaf-classification/prepare_val.py +116 -0
- mlebench/competitions/learning-agency-lab-automated-essay-scoring-2/grade.py +44 -0
- mlebench/competitions/learning-agency-lab-automated-essay-scoring-2/prepare.py +51 -0
- mlebench/competitions/learning-agency-lab-automated-essay-scoring-2/prepare_val.py +96 -0
- mlebench/competitions/liverpool-ion-switching/__init__.py +0 -0
- mlebench/competitions/liverpool-ion-switching/grade.py +52 -0
- mlebench/competitions/liverpool-ion-switching/prepare.py +27 -0
- mlebench/competitions/liverpool_ion_switching/__init__.py +0 -0
- mlebench/competitions/liverpool_ion_switching/grade.py +52 -0
- mlebench/competitions/liverpool_ion_switching/prepare.py +27 -0
- mlebench/competitions/lmsys-chatbot-arena/grade.py +63 -0
- mlebench/competitions/lmsys-chatbot-arena/prepare.py +52 -0
- mlebench/competitions/lmsys-chatbot-arena/prepare_val.py +115 -0
- mlebench/competitions/mcm_2024_c_test/grade.py +107 -0
- mlebench/competitions/mcm_2024_c_test/prepare.py +2 -0
- mlebench/competitions/ml2021spring-hw2/grade.py +11 -0
- mlebench/competitions/ml2021spring-hw2/prepare.py +58 -0
- mlebench/competitions/ml2021spring-hw2/prepare_val.py +135 -0
- mlebench/competitions/mlsp-2013-birds/grade.py +11 -0
- mlebench/competitions/mlsp-2013-birds/prepare.py +182 -0
- mlebench/competitions/mlsp-2013-birds/prepare_val.py +241 -0
- mlebench/competitions/movie-review-sentiment-analysis-kernels-only/grade.py +11 -0
- mlebench/competitions/movie-review-sentiment-analysis-kernels-only/prepare.py +58 -0
- mlebench/competitions/movie-review-sentiment-analysis-kernels-only/prepare_val.py +120 -0
- mlebench/competitions/multi-modal-gesture-recognition/grade.py +58 -0
- mlebench/competitions/multi-modal-gesture-recognition/prepare.py +85 -0
- mlebench/competitions/multi-modal-gesture-recognition/prepare_val.py +139 -0
- mlebench/competitions/my-custom-task-01/prepare.py +2 -0
- mlebench/competitions/new-my-task-01/prepare.py +2 -0
- mlebench/competitions/new-my-task-03/grade.py +107 -0
- mlebench/competitions/new-my-task-03/prepare.py +2 -0
- mlebench/competitions/new-york-city-taxi-fare-prediction/grade.py +28 -0
- mlebench/competitions/new-york-city-taxi-fare-prediction/prepare.py +44 -0
- mlebench/competitions/new-york-city-taxi-fare-prediction/prepare_val.py +89 -0
- mlebench/competitions/nfl-player-contact-detection/grade.py +36 -0
- mlebench/competitions/nfl-player-contact-detection/prepare.py +101 -0
- mlebench/competitions/nfl-player-contact-detection/prepare_val.py +186 -0
- mlebench/competitions/nomad2018-predict-transparent-conductors/grade.py +47 -0
- mlebench/competitions/nomad2018-predict-transparent-conductors/prepare.py +77 -0
- mlebench/competitions/nomad2018-predict-transparent-conductors/prepare_val.py +144 -0
- mlebench/competitions/osic-pulmonary-fibrosis-progression/grade.py +74 -0
- mlebench/competitions/osic-pulmonary-fibrosis-progression/prepare.py +95 -0
- mlebench/competitions/osic-pulmonary-fibrosis-progression/prepare_val.py +167 -0
- mlebench/competitions/paddy-disease-classification/grade.py +35 -0
- mlebench/competitions/paddy-disease-classification/prepare.py +69 -0
- mlebench/competitions/paddy-disease-classification/prepare_val.py +122 -0
- mlebench/competitions/petfinder-pawpularity-score/grade.py +41 -0
- mlebench/competitions/petfinder-pawpularity-score/prepare.py +76 -0
- mlebench/competitions/petfinder-pawpularity-score/prepare_val.py +154 -0
- mlebench/competitions/plant-pathology-2020-fgvc7/grade.py +41 -0
- mlebench/competitions/plant-pathology-2020-fgvc7/prepare.py +74 -0
- mlebench/competitions/plant-pathology-2020-fgvc7/prepare_val.py +160 -0
- mlebench/competitions/plant-pathology-2021-fgvc8/grade.py +54 -0
- mlebench/competitions/plant-pathology-2021-fgvc8/prepare.py +65 -0
- mlebench/competitions/plant-pathology-2021-fgvc8/prepare_val.py +130 -0
- mlebench/competitions/plant-seedlings-classification/grade.py +39 -0
- mlebench/competitions/plant-seedlings-classification/prepare.py +91 -0
- mlebench/competitions/plant-seedlings-classification/prepare_val.py +158 -0
- mlebench/competitions/playground-series-s3e1/__init__.py +0 -0
- mlebench/competitions/playground-series-s3e1/grade.py +52 -0
- mlebench/competitions/playground-series-s3e1/prepare.py +25 -0
- mlebench/competitions/playground-series-s3e11/__init__.py +0 -0
- mlebench/competitions/playground-series-s3e11/grade.py +55 -0
- mlebench/competitions/playground-series-s3e11/prepare.py +25 -0
- mlebench/competitions/playground-series-s3e18/grade.py +39 -0
- mlebench/competitions/playground-series-s3e18/prepare.py +36 -0
- mlebench/competitions/playground-series-s3e18/prepare_val.py +89 -0
- mlebench/competitions/playground_series_s3e1/__init__.py +0 -0
- mlebench/competitions/playground_series_s3e1/grade.py +52 -0
- mlebench/competitions/playground_series_s3e1/prepare.py +25 -0
- mlebench/competitions/playground_series_s3e11/__init__.py +0 -0
- mlebench/competitions/playground_series_s3e11/grade.py +55 -0
- mlebench/competitions/playground_series_s3e11/prepare.py +25 -0
- mlebench/competitions/predict-volcanic-eruptions-ingv-oe/grade.py +44 -0
- mlebench/competitions/predict-volcanic-eruptions-ingv-oe/prepare.py +68 -0
- mlebench/competitions/predict-volcanic-eruptions-ingv-oe/prepare_val.py +146 -0
- mlebench/competitions/random-acts-of-pizza/grade.py +14 -0
- mlebench/competitions/random-acts-of-pizza/prepare.py +80 -0
- mlebench/competitions/random-acts-of-pizza/prepare_val.py +144 -0
- mlebench/competitions/ranzcr-clip-catheter-line-classification/classes.py +11 -0
- mlebench/competitions/ranzcr-clip-catheter-line-classification/grade.py +31 -0
- mlebench/competitions/ranzcr-clip-catheter-line-classification/prepare.py +53 -0
- mlebench/competitions/ranzcr-clip-catheter-line-classification/prepare_val.py +113 -0
- mlebench/competitions/rsna-2022-cervical-spine-fracture-detection/grade.py +124 -0
- mlebench/competitions/rsna-2022-cervical-spine-fracture-detection/prepare.py +219 -0
- mlebench/competitions/rsna-2022-cervical-spine-fracture-detection/prepare_val.py +257 -0
- mlebench/competitions/rsna-breast-cancer-detection/grade.py +65 -0
- mlebench/competitions/rsna-breast-cancer-detection/prepare.py +141 -0
- mlebench/competitions/rsna-breast-cancer-detection/prepare_val.py +201 -0
- mlebench/competitions/rsna-miccai-brain-tumor-radiogenomic-classification/grade.py +13 -0
- mlebench/competitions/rsna-miccai-brain-tumor-radiogenomic-classification/prepare.py +47 -0
- mlebench/competitions/rsna-miccai-brain-tumor-radiogenomic-classification/prepare_val.py +97 -0
- mlebench/competitions/santander-customer-satisfaction/grade.py +10 -0
- mlebench/competitions/santander-customer-satisfaction/prepare.py +41 -0
- mlebench/competitions/sciencebench-001-clintox-nn/__init__.py +0 -0
- mlebench/competitions/sciencebench-001-clintox-nn/grade.py +56 -0
- mlebench/competitions/sciencebench-001-clintox-nn/prepare.py +75 -0
- mlebench/competitions/sciencebench-015-aai/grade.py +37 -0
- mlebench/competitions/sciencebench-015-aai/prepare.py +102 -0
- mlebench/competitions/sciencebench-051-brain-blood-qsar/grade.py +58 -0
- mlebench/competitions/sciencebench-051-brain-blood-qsar/prepare.py +69 -0
- mlebench/competitions/sciencebench-101-experimental-band-gap-prediction/grade.py +55 -0
- mlebench/competitions/sciencebench-101-experimental-band-gap-prediction/prepare.py +88 -0
- mlebench/competitions/see-click-predict-fix/__init__.py +0 -0
- mlebench/competitions/see-click-predict-fix/grade.py +66 -0
- mlebench/competitions/see-click-predict-fix/prepare.py +25 -0
- mlebench/competitions/see_click_predict_fix/__init__.py +0 -0
- mlebench/competitions/see_click_predict_fix/grade.py +66 -0
- mlebench/competitions/see_click_predict_fix/prepare.py +25 -0
- mlebench/competitions/seti-breakthrough-listen/grade.py +11 -0
- mlebench/competitions/seti-breakthrough-listen/prepare.py +71 -0
- mlebench/competitions/seti-breakthrough-listen/prepare_val.py +159 -0
- mlebench/competitions/siim-covid19-detection/grade.py +194 -0
- mlebench/competitions/siim-covid19-detection/prepare.py +123 -0
- mlebench/competitions/siim-covid19-detection/prepare_val.py +164 -0
- mlebench/competitions/siim-isic-melanoma-classification/grade.py +11 -0
- mlebench/competitions/siim-isic-melanoma-classification/prepare.py +127 -0
- mlebench/competitions/siim-isic-melanoma-classification/prepare_val.py +158 -0
- mlebench/competitions/smartphone-decimeter-2022/grade.py +55 -0
- mlebench/competitions/smartphone-decimeter-2022/notebook.py +86 -0
- mlebench/competitions/smartphone-decimeter-2022/prepare.py +143 -0
- mlebench/competitions/smartphone-decimeter-2022/prepare_val.py +199 -0
- mlebench/competitions/spaceship-titanic/grade.py +11 -0
- mlebench/competitions/spaceship-titanic/prepare.py +23 -0
- mlebench/competitions/spaceship-titanic/prepare_val.py +61 -0
- mlebench/competitions/spooky-author-identification/classes.py +1 -0
- mlebench/competitions/spooky-author-identification/grade.py +38 -0
- mlebench/competitions/spooky-author-identification/prepare.py +40 -0
- mlebench/competitions/spooky-author-identification/prepare_val.py +78 -0
- mlebench/competitions/stanford-covid-vaccine/grade.py +65 -0
- mlebench/competitions/stanford-covid-vaccine/prepare.py +129 -0
- mlebench/competitions/stanford-covid-vaccine/prepare_val.py +199 -0
- mlebench/competitions/statoil-iceberg-classifier-challenge/grade.py +41 -0
- mlebench/competitions/statoil-iceberg-classifier-challenge/prepare.py +105 -0
- mlebench/competitions/statoil-iceberg-classifier-challenge/prepare_val.py +157 -0
- mlebench/competitions/tabular-playground-series-dec-2021/grade.py +11 -0
- mlebench/competitions/tabular-playground-series-dec-2021/prepare.py +39 -0
- mlebench/competitions/tabular-playground-series-dec-2021/prepare_val.py +99 -0
- mlebench/competitions/tabular-playground-series-may-2022/grade.py +9 -0
- mlebench/competitions/tabular-playground-series-may-2022/prepare.py +56 -0
- mlebench/competitions/tabular-playground-series-may-2022/prepare_val.py +116 -0
- mlebench/competitions/tensorflow-speech-recognition-challenge/grade.py +11 -0
- mlebench/competitions/tensorflow-speech-recognition-challenge/prepare.py +90 -0
- mlebench/competitions/tensorflow-speech-recognition-challenge/prepare_val.py +148 -0
- mlebench/competitions/tensorflow2-question-answering/grade.py +122 -0
- mlebench/competitions/tensorflow2-question-answering/prepare.py +122 -0
- mlebench/competitions/tensorflow2-question-answering/prepare_val.py +187 -0
- mlebench/competitions/text-normalization-challenge-english-language/grade.py +49 -0
- mlebench/competitions/text-normalization-challenge-english-language/prepare.py +115 -0
- mlebench/competitions/text-normalization-challenge-english-language/prepare_val.py +213 -0
- mlebench/competitions/text-normalization-challenge-russian-language/grade.py +49 -0
- mlebench/competitions/text-normalization-challenge-russian-language/prepare.py +113 -0
- mlebench/competitions/text-normalization-challenge-russian-language/prepare_val.py +165 -0
- mlebench/competitions/tgs-salt-identification-challenge/grade.py +144 -0
- mlebench/competitions/tgs-salt-identification-challenge/prepare.py +158 -0
- mlebench/competitions/tgs-salt-identification-challenge/prepare_val.py +166 -0
- mlebench/competitions/the-icml-2013-whale-challenge-right-whale-redux/grade.py +11 -0
- mlebench/competitions/the-icml-2013-whale-challenge-right-whale-redux/prepare.py +95 -0
- mlebench/competitions/the-icml-2013-whale-challenge-right-whale-redux/prepare_val.py +141 -0
- mlebench/competitions/tmdb-box-office-prediction/__init__.py +0 -0
- mlebench/competitions/tmdb-box-office-prediction/grade.py +55 -0
- mlebench/competitions/tmdb-box-office-prediction/prepare.py +35 -0
- mlebench/competitions/tweet-sentiment-extraction/grade.py +67 -0
- mlebench/competitions/tweet-sentiment-extraction/prepare.py +36 -0
- mlebench/competitions/tweet-sentiment-extraction/prepare_val.py +106 -0
- mlebench/competitions/us-patent-phrase-to-phrase-matching/grade.py +31 -0
- mlebench/competitions/us-patent-phrase-to-phrase-matching/prepare.py +33 -0
- mlebench/competitions/us-patent-phrase-to-phrase-matching/prepare_val.py +71 -0
- mlebench/competitions/utils.py +266 -0
- mlebench/competitions/uw-madison-gi-tract-image-segmentation/grade.py +158 -0
- mlebench/competitions/uw-madison-gi-tract-image-segmentation/prepare.py +139 -0
- mlebench/competitions/uw-madison-gi-tract-image-segmentation/prepare_val.py +193 -0
- mlebench/competitions/ventilator-pressure-prediction/__init__.py +0 -0
- mlebench/competitions/ventilator-pressure-prediction/grade.py +52 -0
- mlebench/competitions/ventilator-pressure-prediction/prepare.py +27 -0
- mlebench/competitions/ventilator-pressure-prediction/prepare_val.py +142 -0
- mlebench/competitions/ventilator_pressure_prediction/__init__.py +0 -0
- mlebench/competitions/ventilator_pressure_prediction/grade.py +52 -0
- mlebench/competitions/ventilator_pressure_prediction/prepare.py +27 -0
- mlebench/competitions/vesuvius-challenge-ink-detection/grade.py +97 -0
- mlebench/competitions/vesuvius-challenge-ink-detection/prepare.py +122 -0
- mlebench/competitions/vesuvius-challenge-ink-detection/prepare_val.py +170 -0
- mlebench/competitions/vinbigdata-chest-xray-abnormalities-detection/grade.py +220 -0
- mlebench/competitions/vinbigdata-chest-xray-abnormalities-detection/prepare.py +129 -0
- mlebench/competitions/vinbigdata-chest-xray-abnormalities-detection/prepare_val.py +204 -0
- mlebench/competitions/whale-categorization-playground/grade.py +41 -0
- mlebench/competitions/whale-categorization-playground/prepare.py +103 -0
- mlebench/competitions/whale-categorization-playground/prepare_val.py +196 -0
- mlebench/data.py +420 -0
- mlebench/grade.py +209 -0
- mlebench/grade_helpers.py +235 -0
- mlebench/metrics.py +75 -0
- mlebench/registry.py +332 -0
- mlebench/utils.py +346 -0
- {dslighting-1.7.1.dist-info → dslighting-1.7.6.dist-info}/WHEEL +0 -0
- {dslighting-1.7.1.dist-info → dslighting-1.7.6.dist-info}/entry_points.txt +0 -0
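Not part of the diff itself, but a quick way to confirm which of these files a locally installed copy of dslighting actually bundles is to read the wheel's RECORD through the standard library. A minimal sketch, assuming only that dslighting is installed in the current environment:

```python
# List the mlebench files recorded for the installed dslighting distribution.
from importlib.metadata import distribution

dist = distribution("dslighting")
print("dslighting", dist.version)

# dist.files is read from the wheel's RECORD; it may be None for some install layouts.
mlebench_files = [p for p in (dist.files or []) if str(p).startswith("mlebench/")]
print(f"{len(mlebench_files)} mlebench files bundled, for example:")
for path in mlebench_files[:5]:
    print("  ", path)
```

The hunk below is the largest single addition in the release: the new `mlebench/competitions/3d-object-detection-for-autonomous-vehicles/mAP_evaluation.py` (+425 lines).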
@@ -0,0 +1,425 @@
+# Verbatim from: https://github.com/lyft/nuscenes-devkit/blob/master/lyft_dataset_sdk/eval/detection/mAP_evaluation.py
+"""
+mAP 3D calculation for the data in nuScenes format.
+
+
+The intput files expected to have the format:
+
+Expected fields:
+
+
+gt = [{
+    'sample_token': '0f0e3ce89d2324d8b45aa55a7b4f8207fbb039a550991a5149214f98cec136ac',
+    'translation': [974.2811881299899, 1714.6815014457964, -23.689857123368846],
+    'size': [1.796, 4.488, 1.664],
+    'rotation': [0.14882026466054782, 0, 0, 0.9888642620837121],
+    'name': 'car'
+}]
+
+prediction_result = {
+    'sample_token': '0f0e3ce89d2324d8b45aa55a7b4f8207fbb039a550991a5149214f98cec136ac',
+    'translation': [971.8343488872263, 1713.6816097857359, -25.82534357061308],
+    'size': [2.519726579986132, 7.810161372666739, 3.483438286096803],
+    'rotation': [0.10913582721095375, 0.04099572636992043, 0.01927712319721745, 1.029328402625659],
+    'name': 'car',
+    'score': 0.3077029437237213
+}
+
+
+input arguments:
+
+--pred_file: file with predictions
+--gt_file: ground truth file
+--iou_threshold: IOU threshold
+
+
+In general we would be interested in average of mAP at thresholds [0.5, 0.55, 0.6, 0.65,...0.95], similar to the
+standard COCO => one needs to run this file N times for every IOU threshold independently.
+
+"""
+
+import argparse
+import json
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+from pyquaternion import Quaternion
+from shapely.geometry import Polygon
+
+
+class Box3D:
+    """Data class used during detection evaluation. Can be a prediction or ground truth."""
+
+    def __init__(self, **kwargs):
+        sample_token = kwargs["sample_token"]
+        translation = kwargs["translation"]
+        size = kwargs["size"]
+        rotation = kwargs["rotation"]
+        name = kwargs["name"]
+        score = kwargs.get("score", -1)
+
+        if not isinstance(sample_token, str):
+            raise TypeError("Sample_token must be a string!")
+
+        if not len(translation) == 3:
+            raise ValueError("Translation must have 3 elements!")
+
+        if np.any(np.isnan(translation)):
+            raise ValueError("Translation may not be NaN!")
+
+        if not len(size) == 3:
+            raise ValueError("Size must have 3 elements!")
+
+        if np.any(np.isnan(size)):
+            raise ValueError("Size may not be NaN!")
+
+        if not len(rotation) == 4:
+            raise ValueError("Rotation must have 4 elements!")
+
+        if np.any(np.isnan(rotation)):
+            raise ValueError("Rotation may not be NaN!")
+
+        if name is None:
+            raise ValueError("Name cannot be empty!")
+
+        # Assign.
+        self.sample_token = sample_token
+        self.translation = translation
+        self.size = size
+        self.volume = np.prod(self.size)
+        self.score = score
+
+        assert np.all([x > 0 for x in size])
+        self.rotation = rotation
+        self.name = name
+        self.quaternion = Quaternion(self.rotation)
+
+        self.width, self.length, self.height = size
+
+        self.center_x, self.center_y, self.center_z = self.translation
+
+        self.min_z = self.center_z - self.height / 2
+        self.max_z = self.center_z + self.height / 2
+
+        self.ground_bbox_coords = None
+        self.ground_bbox_coords = self.get_ground_bbox_coords()
+
+    @staticmethod
+    def check_orthogonal(a, b, c):
+        """Check that vector (b - a) is orthogonal to the vector (c - a)."""
+        return np.isclose((b[0] - a[0]) * (c[0] - a[0]) + (b[1] - a[1]) * (c[1] - a[1]), 0)
+
+    def get_ground_bbox_coords(self):
+        if self.ground_bbox_coords is not None:
+            return self.ground_bbox_coords
+        return self.calculate_ground_bbox_coords()
+
+    def calculate_ground_bbox_coords(self):
+        """We assume that the 3D box has lower plane parallel to the ground.
+
+        Returns: Polygon with 4 points describing the base.
+
+        """
+        if self.ground_bbox_coords is not None:
+            return self.ground_bbox_coords
+
+        rotation_matrix = self.quaternion.rotation_matrix
+
+        cos_angle = rotation_matrix[0, 0]
+        sin_angle = rotation_matrix[1, 0]
+
+        point_0_x = self.center_x + self.length / 2 * cos_angle + self.width / 2 * sin_angle
+        point_0_y = self.center_y + self.length / 2 * sin_angle - self.width / 2 * cos_angle
+
+        point_1_x = self.center_x + self.length / 2 * cos_angle - self.width / 2 * sin_angle
+        point_1_y = self.center_y + self.length / 2 * sin_angle + self.width / 2 * cos_angle
+
+        point_2_x = self.center_x - self.length / 2 * cos_angle - self.width / 2 * sin_angle
+        point_2_y = self.center_y - self.length / 2 * sin_angle + self.width / 2 * cos_angle
+
+        point_3_x = self.center_x - self.length / 2 * cos_angle + self.width / 2 * sin_angle
+        point_3_y = self.center_y - self.length / 2 * sin_angle - self.width / 2 * cos_angle
+
+        point_0 = point_0_x, point_0_y
+        point_1 = point_1_x, point_1_y
+        point_2 = point_2_x, point_2_y
+        point_3 = point_3_x, point_3_y
+
+        assert self.check_orthogonal(point_0, point_1, point_3)
+        assert self.check_orthogonal(point_1, point_0, point_2)
+        assert self.check_orthogonal(point_2, point_1, point_3)
+        assert self.check_orthogonal(point_3, point_0, point_2)
+
+        self.ground_bbox_coords = Polygon(
+            [
+                (point_0_x, point_0_y),
+                (point_1_x, point_1_y),
+                (point_2_x, point_2_y),
+                (point_3_x, point_3_y),
+                (point_0_x, point_0_y),
+            ]
+        )
+
+        return self.ground_bbox_coords
+
+    def get_height_intersection(self, other):
+        min_z = max(other.min_z, self.min_z)
+        max_z = min(other.max_z, self.max_z)
+
+        return max(0, max_z - min_z)
+
+    def get_area_intersection(self, other) -> float:
+        result = self.ground_bbox_coords.intersection(other.ground_bbox_coords).area
+
+        assert result <= self.width * self.length
+
+        return result
+
+    def get_intersection(self, other) -> float:
+        height_intersection = self.get_height_intersection(other)
+
+        area_intersection = self.ground_bbox_coords.intersection(other.ground_bbox_coords).area
+
+        return height_intersection * area_intersection
+
+    def get_iou(self, other):
+        intersection = self.get_intersection(other)
+        union = self.volume + other.volume - intersection
+
+        iou = np.clip(intersection / union, 0, 1)
+
+        return iou
+
+    def __repr__(self):
+        return str(self.serialize())
+
+    def serialize(self) -> dict:
+        """Returns: Serialized instance as dict."""
+
+        return {
+            "sample_token": self.sample_token,
+            "translation": self.translation,
+            "size": self.size,
+            "rotation": self.rotation,
+            "name": self.name,
+            "volume": self.volume,
+            "score": self.score,
+        }
+
+
+def group_by_key(detections, key):
+    groups = defaultdict(list)
+    for detection in detections:
+        groups[detection[key]].append(detection)
+    return groups
+
+
+def wrap_in_box(input):
+    result = {}
+    for key, value in input.items():
+        result[key] = [Box3D(**x) for x in value]
+
+    return result
+
+
+def get_envelope(precisions):
+    """Compute the precision envelope.
+
+    Args:
+        precisions:
+
+    Returns:
+
+    """
+    for i in range(precisions.size - 1, 0, -1):
+        precisions[i - 1] = np.maximum(precisions[i - 1], precisions[i])
+    return precisions
+
+
+def get_ap(recalls, precisions):
+    """Calculate average precision.
+
+    Args:
+        recalls:
+        precisions: Returns (float): average precision.
+
+    Returns:
+
+    """
+    # correct AP calculation
+    # first append sentinel values at the end
+    recalls = np.concatenate(([0.0], recalls, [1.0]))
+    precisions = np.concatenate(([0.0], precisions, [0.0]))
+
+    precisions = get_envelope(precisions)
+
+    # to calculate area under PR curve, look for points where X axis (recall) changes value
+    i = np.where(recalls[1:] != recalls[:-1])[0]
+
+    # and sum (\Delta recall) * prec
+    ap = np.sum((recalls[i + 1] - recalls[i]) * precisions[i + 1])
+    return ap
+
+
+def get_ious(gt_boxes, predicted_box):
+    return [predicted_box.get_iou(x) for x in gt_boxes]
+
+
+def recall_precision(gt, predictions, iou_threshold):
+    num_gts = len(gt)
+    image_gts = group_by_key(gt, "sample_token")
+    image_gts = wrap_in_box(image_gts)
+
+    sample_gt_checked = {
+        sample_token: np.zeros(len(boxes)) for sample_token, boxes in image_gts.items()
+    }
+
+    predictions = sorted(predictions, key=lambda x: x["score"], reverse=True)
+
+    # go down dets and mark TPs and FPs
+    num_predictions = len(predictions)
+    tp = np.zeros(num_predictions)
+    fp = np.zeros(num_predictions)
+
+    for prediction_index, prediction in enumerate(predictions):
+        predicted_box = Box3D(**prediction)
+
+        sample_token = prediction["sample_token"]
+
+        max_overlap = -np.inf
+        jmax = -1
+
+        try:
+            gt_boxes = image_gts[sample_token]  # gt_boxes per sample
+            gt_checked = sample_gt_checked[sample_token]  # gt flags per sample
+        except KeyError:
+            gt_boxes = []
+            gt_checked = None
+
+        if len(gt_boxes) > 0:
+            overlaps = get_ious(gt_boxes, predicted_box)
+
+            max_overlap = np.max(overlaps)
+
+            jmax = np.argmax(overlaps)
+
+        if max_overlap > iou_threshold:
+            if gt_checked[jmax] == 0:
+                tp[prediction_index] = 1.0
+                gt_checked[jmax] = 1
+            else:
+                fp[prediction_index] = 1.0
+        else:
+            fp[prediction_index] = 1.0
+
+    # compute precision recall
+    fp = np.cumsum(fp, axis=0)
+    tp = np.cumsum(tp, axis=0)
+
+    recalls = tp / float(num_gts)
+
+    assert np.all(0 <= recalls) & np.all(recalls <= 1)
+
+    # avoid divide by zero in case the first detection matches a difficult ground truth
+    precisions = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
+
+    assert np.all(0 <= precisions) & np.all(precisions <= 1)
+
+    ap = get_ap(recalls, precisions)
+
+    return recalls, precisions, ap
+
+
+def get_average_precisions(
+    gt: list, predictions: list, class_names: list, iou_threshold: float
+) -> np.array:
+    """Returns an array with an average precision per class.
+
+
+    Args:
+        gt: list of dictionaries in the format described below.
+        predictions: list of dictionaries in the format described below.
+        class_names: list of the class names.
+        iou_threshold: IOU threshold used to calculate TP / FN
+
+    Returns an array with an average precision per class.
+
+
+    Ground truth and predictions should have schema:
+
+    gt = [{
+        'sample_token': '0f0e3ce89d2324d8b45aa55a7b4f8207fbb039a550991a5149214f98cec136ac',
+        'translation': [974.2811881299899, 1714.6815014457964, -23.689857123368846],
+        'size': [1.796, 4.488, 1.664],
+        'rotation': [0.14882026466054782, 0, 0, 0.9888642620837121],
+        'name': 'car'
+    }]
+
+    predictions = [{
+        'sample_token': '0f0e3ce89d2324d8b45aa55a7b4f8207fbb039a550991a5149214f98cec136ac',
+        'translation': [971.8343488872263, 1713.6816097857359, -25.82534357061308],
+        'size': [2.519726579986132, 7.810161372666739, 3.483438286096803],
+        'rotation': [0.10913582721095375, 0.04099572636992043, 0.01927712319721745, 1.029328402625659],
+        'name': 'car',
+        'score': 0.3077029437237213
+    }]
+
+    """
+    assert 0 <= iou_threshold <= 1
+
+    gt_by_class_name = group_by_key(gt, "name")
+    pred_by_class_name = group_by_key(predictions, "name")
+
+    average_precisions = np.zeros(len(class_names))
+
+    for class_id, class_name in enumerate(class_names):
+        if class_name in pred_by_class_name:
+            recalls, precisions, average_precision = recall_precision(
+                gt_by_class_name[class_name], pred_by_class_name[class_name], iou_threshold
+            )
+            average_precisions[class_id] = average_precision
+
+    return average_precisions
+
+
+def get_class_names(gt: dict) -> list:
+    """Get sorted list of class names.
+
+    Args:
+        gt:
+
+    Returns: Sorted list of class names.
+
+    """
+    return sorted(list(set([x["name"] for x in gt])))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    arg = parser.add_argument
+    arg("-p", "--pred_file", type=str, help="Path to the predictions file.", required=True)
+    arg("-g", "--gt_file", type=str, help="Path to the ground truth file.", required=True)
+    arg("-t", "--iou_threshold", type=float, help="iou threshold", default=0.5)
+
+    args = parser.parse_args()

+    gt_path = Path(args.gt_file)
+    pred_path = Path(args.pred_file)
+
+    with open(args.pred_file) as f:
+        predictions = json.load(f)
+
+    with open(args.gt_file) as f:
+        gt = json.load(f)
+
+    class_names = get_class_names(gt)
+    print("Class_names = ", class_names)
+
+    average_precisions = get_average_precisions(gt, predictions, class_names, args.iou_threshold)
+
+    mAP = np.mean(average_precisions)
+    print("Average per class mean average precision = ", mAP)
+
+    for class_id in sorted(list(zip(class_names, average_precisions.flatten().tolist()))):
+        print(class_id)