clarifai 11.3.0rc2__py3-none-any.whl → 11.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/__init__.py +1 -1
- clarifai/cli/__main__.py +1 -1
- clarifai/cli/base.py +144 -136
- clarifai/cli/compute_cluster.py +45 -31
- clarifai/cli/deployment.py +93 -76
- clarifai/cli/model.py +578 -180
- clarifai/cli/nodepool.py +100 -82
- clarifai/client/__init__.py +12 -2
- clarifai/client/app.py +973 -911
- clarifai/client/auth/helper.py +345 -342
- clarifai/client/auth/register.py +7 -7
- clarifai/client/auth/stub.py +107 -106
- clarifai/client/base.py +185 -178
- clarifai/client/compute_cluster.py +214 -180
- clarifai/client/dataset.py +793 -698
- clarifai/client/deployment.py +55 -50
- clarifai/client/input.py +1223 -1088
- clarifai/client/lister.py +47 -45
- clarifai/client/model.py +1939 -1717
- clarifai/client/model_client.py +525 -502
- clarifai/client/module.py +82 -73
- clarifai/client/nodepool.py +358 -213
- clarifai/client/runner.py +58 -0
- clarifai/client/search.py +342 -309
- clarifai/client/user.py +419 -414
- clarifai/client/workflow.py +294 -274
- clarifai/constants/dataset.py +11 -17
- clarifai/constants/model.py +8 -2
- clarifai/datasets/export/inputs_annotations.py +233 -217
- clarifai/datasets/upload/base.py +63 -51
- clarifai/datasets/upload/features.py +43 -38
- clarifai/datasets/upload/image.py +237 -207
- clarifai/datasets/upload/loaders/coco_captions.py +34 -32
- clarifai/datasets/upload/loaders/coco_detection.py +72 -65
- clarifai/datasets/upload/loaders/imagenet_classification.py +57 -53
- clarifai/datasets/upload/loaders/xview_detection.py +274 -132
- clarifai/datasets/upload/multimodal.py +55 -46
- clarifai/datasets/upload/text.py +55 -47
- clarifai/datasets/upload/utils.py +250 -234
- clarifai/errors.py +51 -50
- clarifai/models/api.py +260 -238
- clarifai/modules/css.py +50 -50
- clarifai/modules/pages.py +33 -33
- clarifai/rag/rag.py +312 -288
- clarifai/rag/utils.py +91 -84
- clarifai/runners/models/model_builder.py +906 -802
- clarifai/runners/models/model_class.py +370 -331
- clarifai/runners/models/model_run_locally.py +459 -419
- clarifai/runners/models/model_runner.py +170 -162
- clarifai/runners/models/model_servicer.py +78 -70
- clarifai/runners/server.py +111 -101
- clarifai/runners/utils/code_script.py +225 -187
- clarifai/runners/utils/const.py +4 -1
- clarifai/runners/utils/data_types/__init__.py +12 -0
- clarifai/runners/utils/data_types/data_types.py +598 -0
- clarifai/runners/utils/data_utils.py +387 -440
- clarifai/runners/utils/loader.py +247 -227
- clarifai/runners/utils/method_signatures.py +411 -386
- clarifai/runners/utils/openai_convertor.py +108 -109
- clarifai/runners/utils/serializers.py +175 -179
- clarifai/runners/utils/url_fetcher.py +35 -35
- clarifai/schema/search.py +56 -63
- clarifai/urls/helper.py +125 -102
- clarifai/utils/cli.py +129 -123
- clarifai/utils/config.py +127 -87
- clarifai/utils/constants.py +49 -0
- clarifai/utils/evaluation/helpers.py +503 -466
- clarifai/utils/evaluation/main.py +431 -393
- clarifai/utils/evaluation/testset_annotation_parser.py +154 -144
- clarifai/utils/logging.py +324 -306
- clarifai/utils/misc.py +60 -56
- clarifai/utils/model_train.py +165 -146
- clarifai/utils/protobuf.py +126 -103
- clarifai/versions.py +3 -1
- clarifai/workflows/export.py +48 -50
- clarifai/workflows/utils.py +39 -36
- clarifai/workflows/validate.py +55 -43
- {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/METADATA +16 -6
- clarifai-11.4.0.dist-info/RECORD +109 -0
- {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/WHEEL +1 -1
- clarifai/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/__pycache__/errors.cpython-310.pyc +0 -0
- clarifai/__pycache__/errors.cpython-311.pyc +0 -0
- clarifai/__pycache__/versions.cpython-310.pyc +0 -0
- clarifai/__pycache__/versions.cpython-311.pyc +0 -0
- clarifai/cli/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/cli/__pycache__/base.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/base.cpython-311.pyc +0 -0
- clarifai/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/compute_cluster.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/compute_cluster.cpython-311.pyc +0 -0
- clarifai/cli/__pycache__/deployment.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/deployment.cpython-311.pyc +0 -0
- clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/model.cpython-311.pyc +0 -0
- clarifai/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/nodepool.cpython-310.pyc +0 -0
- clarifai/cli/__pycache__/nodepool.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/client/__pycache__/app.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/app.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/app.cpython-39.pyc +0 -0
- clarifai/client/__pycache__/base.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/base.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/compute_cluster.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/compute_cluster.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/dataset.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/dataset.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/deployment.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/deployment.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/input.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/input.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/lister.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/lister.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/model.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/model.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/module.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/module.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/nodepool.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/nodepool.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/search.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/search.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/user.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/user.cpython-311.pyc +0 -0
- clarifai/client/__pycache__/workflow.cpython-310.pyc +0 -0
- clarifai/client/__pycache__/workflow.cpython-311.pyc +0 -0
- clarifai/client/auth/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/client/auth/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/client/auth/__pycache__/helper.cpython-310.pyc +0 -0
- clarifai/client/auth/__pycache__/helper.cpython-311.pyc +0 -0
- clarifai/client/auth/__pycache__/register.cpython-310.pyc +0 -0
- clarifai/client/auth/__pycache__/register.cpython-311.pyc +0 -0
- clarifai/client/auth/__pycache__/stub.cpython-310.pyc +0 -0
- clarifai/client/auth/__pycache__/stub.cpython-311.pyc +0 -0
- clarifai/client/cli/__init__.py +0 -0
- clarifai/client/cli/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/client/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
- clarifai/client/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
- clarifai/client/cli/base_cli.py +0 -88
- clarifai/client/cli/model_cli.py +0 -29
- clarifai/constants/__pycache__/base.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/base.cpython-311.pyc +0 -0
- clarifai/constants/__pycache__/dataset.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/dataset.cpython-311.pyc +0 -0
- clarifai/constants/__pycache__/input.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/input.cpython-311.pyc +0 -0
- clarifai/constants/__pycache__/model.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/model.cpython-311.pyc +0 -0
- clarifai/constants/__pycache__/rag.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/rag.cpython-311.pyc +0 -0
- clarifai/constants/__pycache__/search.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/search.cpython-311.pyc +0 -0
- clarifai/constants/__pycache__/workflow.cpython-310.pyc +0 -0
- clarifai/constants/__pycache__/workflow.cpython-311.pyc +0 -0
- clarifai/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/datasets/export/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/datasets/export/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/datasets/export/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/datasets/export/__pycache__/inputs_annotations.cpython-310.pyc +0 -0
- clarifai/datasets/export/__pycache__/inputs_annotations.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/datasets/upload/__pycache__/base.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/base.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/features.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/features.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/image.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/image.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/multimodal.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/multimodal.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/text.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/text.cpython-311.pyc +0 -0
- clarifai/datasets/upload/__pycache__/utils.cpython-310.pyc +0 -0
- clarifai/datasets/upload/__pycache__/utils.cpython-311.pyc +0 -0
- clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-311.pyc +0 -0
- clarifai/datasets/upload/loaders/__pycache__/imagenet_classification.cpython-311.pyc +0 -0
- clarifai/models/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/rag/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/rag/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/rag/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
- clarifai/rag/__pycache__/rag.cpython-311.pyc +0 -0
- clarifai/rag/__pycache__/rag.cpython-39.pyc +0 -0
- clarifai/rag/__pycache__/utils.cpython-310.pyc +0 -0
- clarifai/rag/__pycache__/utils.cpython-311.pyc +0 -0
- clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/runners/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/runners/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/runners/dockerfile_template/Dockerfile.cpu.template +0 -31
- clarifai/runners/dockerfile_template/Dockerfile.cuda.template +0 -42
- clarifai/runners/dockerfile_template/Dockerfile.nim +0 -71
- clarifai/runners/models/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/runners/models/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/base_typed_model.cpython-311.pyc +0 -0
- clarifai/runners/models/__pycache__/base_typed_model.cpython-39.pyc +0 -0
- clarifai/runners/models/__pycache__/model_builder.cpython-311.pyc +0 -0
- clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_class.cpython-311.pyc +0 -0
- clarifai/runners/models/__pycache__/model_run_locally.cpython-310-pytest-7.1.2.pyc +0 -0
- clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_run_locally.cpython-311.pyc +0 -0
- clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
- clarifai/runners/models/__pycache__/model_runner.cpython-311.pyc +0 -0
- clarifai/runners/models/__pycache__/model_upload.cpython-310.pyc +0 -0
- clarifai/runners/models/base_typed_model.py +0 -238
- clarifai/runners/models/model_class_refract.py +0 -80
- clarifai/runners/models/model_upload.py +0 -607
- clarifai/runners/models/temp.py +0 -25
- clarifai/runners/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/runners/utils/__pycache__/__init__.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/buffered_stream.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/buffered_stream.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/buffered_stream.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/const.cpython-311.pyc +0 -0
- clarifai/runners/utils/__pycache__/constants.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/constants.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/constants.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_handler.cpython-311.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_handler.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_handler.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_utils.cpython-311.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_utils.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/data_utils.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/grpc_server.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/grpc_server.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/grpc_server.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/health.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/health.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/health.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/loader.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/loader.cpython-311.pyc +0 -0
- clarifai/runners/utils/__pycache__/logging.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/logging.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/logging.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/stream_source.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/stream_source.cpython-39.pyc +0 -0
- clarifai/runners/utils/__pycache__/url_fetcher.cpython-310.pyc +0 -0
- clarifai/runners/utils/__pycache__/url_fetcher.cpython-311.pyc +0 -0
- clarifai/runners/utils/__pycache__/url_fetcher.cpython-38.pyc +0 -0
- clarifai/runners/utils/__pycache__/url_fetcher.cpython-39.pyc +0 -0
- clarifai/runners/utils/data_handler.py +0 -231
- clarifai/runners/utils/data_handler_refract.py +0 -213
- clarifai/runners/utils/data_types.py +0 -469
- clarifai/runners/utils/logger.py +0 -0
- clarifai/runners/utils/openai_format.py +0 -87
- clarifai/schema/__pycache__/search.cpython-310.pyc +0 -0
- clarifai/schema/__pycache__/search.cpython-311.pyc +0 -0
- clarifai/urls/__pycache__/helper.cpython-310.pyc +0 -0
- clarifai/urls/__pycache__/helper.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/utils/__pycache__/cli.cpython-310.pyc +0 -0
- clarifai/utils/__pycache__/cli.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/config.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/constants.cpython-310.pyc +0 -0
- clarifai/utils/__pycache__/constants.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/logging.cpython-310.pyc +0 -0
- clarifai/utils/__pycache__/logging.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/misc.cpython-310.pyc +0 -0
- clarifai/utils/__pycache__/misc.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/model_train.cpython-310.pyc +0 -0
- clarifai/utils/__pycache__/model_train.cpython-311.pyc +0 -0
- clarifai/utils/__pycache__/protobuf.cpython-311.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/helpers.cpython-311.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/main.cpython-311.pyc +0 -0
- clarifai/utils/evaluation/__pycache__/main.cpython-39.pyc +0 -0
- clarifai/workflows/__pycache__/__init__.cpython-310.pyc +0 -0
- clarifai/workflows/__pycache__/__init__.cpython-311.pyc +0 -0
- clarifai/workflows/__pycache__/__init__.cpython-39.pyc +0 -0
- clarifai/workflows/__pycache__/export.cpython-310.pyc +0 -0
- clarifai/workflows/__pycache__/export.cpython-311.pyc +0 -0
- clarifai/workflows/__pycache__/utils.cpython-310.pyc +0 -0
- clarifai/workflows/__pycache__/utils.cpython-311.pyc +0 -0
- clarifai/workflows/__pycache__/validate.cpython-310.pyc +0 -0
- clarifai/workflows/__pycache__/validate.cpython-311.pyc +0 -0
- clarifai-11.3.0rc2.dist-info/RECORD +0 -322
- {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/entry_points.txt +0 -0
- {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info/licenses}/LICENSE +0 -0
- {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/top_level.txt +0 -0
clarifai/utils/evaluation/helpers.py

@@ -13,515 +13,552 @@ from clarifai.client.dataset import Dataset

This hunk is almost entirely a formatting pass: the recoverable old-side fragments match the new code apart from two-space indentation, different line wrapping, and trivia such as `_f1` returning `0.` instead of `0.0`. The post-change (11.4.0) side of the hunk is reconstructed below; imports above line 13 (enum, dataclasses, numpy, the protobuf helpers) fall outside the diff.

```python
from clarifai.client.model import Model

try:
    import pandas as pd
except ImportError:
    raise ImportError("Can not import pandas. Please run `pip install pandas` to install it")

try:
    from loguru import logger
except ImportError:
    from ..logging import logger

MACRO_AVG = "macro_avg"


class EvalType(Enum):
    UNDEFINED = 0
    CLASSIFICATION = 1
    DETECTION = 2
    CLUSTERING = 3
    SEGMENTATION = 4
    TRACKER = 5


def get_eval_type(model_type):
    if "classifier" in model_type:
        return EvalType.CLASSIFICATION
    elif "visual-detector" in model_type:
        return EvalType.DETECTION
    elif "segmenter" in model_type:
        return EvalType.SEGMENTATION
    elif "embedder" in model_type:
        return EvalType.CLUSTERING
    elif "tracker" in model_type:
        return EvalType.TRACKER
    else:
        return EvalType.UNDEFINED


def to_file_name(x) -> str:
    return x.replace('/', '--')
```
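`get_eval_type` dispatches purely on substring matching against the model-type string, so any type id containing `classifier` maps to classification and unknown types fall through to `UNDEFINED`. A couple of illustrative calls (the type strings here are examples, not an exhaustive list of Clarifai model types):

```python
from clarifai.utils.evaluation.helpers import EvalType, get_eval_type

assert get_eval_type("visual-classifier") == EvalType.CLASSIFICATION  # contains "classifier"
assert get_eval_type("visual-detector") == EvalType.DETECTION
assert get_eval_type("text-to-text") == EvalType.UNDEFINED  # no keyword matches
```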
Continuing the hunk, the shared handler base class:

```python
@dataclass
class _BaseEvalResultHandler:
    model: Model
    eval_data: List[resources_pb2.EvalMetrics] = field(default_factory=list)

    def evaluate_and_wait(self, dataset: Dataset, eval_info: dict = None):
        from tqdm import tqdm

        dataset_id = dataset.id
        dataset_app_id = dataset.app_id
        dataset_user_id = dataset.user_id
        _ = self.model.evaluate(
            dataset_id=dataset_id,
            dataset_app_id=dataset_app_id,
            dataset_user_id=dataset_user_id,
            eval_info=eval_info,
        )
        latest_eval = self.model.list_evaluations()[0]
        excepted = 10
        desc = f"Please wait for the evaluation process between model {self.get_model_name()} and dataset {dataset_user_id}/{dataset_app_id}/{dataset_id} to complete."
        bar = tqdm(total=excepted, desc=desc, leave=False, ncols=0)
        while latest_eval.status.code in [
            status_code_pb2.MODEL_EVALUATING,
            status_code_pb2.MODEL_QUEUED_FOR_EVALUATION,
        ]:
            latest_eval = self.model.list_evaluations()[0]
            time.sleep(1)
            bar.update(1)

        if latest_eval.status.code == status_code_pb2.MODEL_EVALUATED:
            return latest_eval
        else:
            raise Exception(
                f"Model has failed to evaluate \n {latest_eval.status}.\nPlease check your dataset inputs!"
            )

    def find_eval_id(
        self, datasets: List[Dataset] = [], attempt_evaluate: bool = False, eval_info: dict = None
    ):
        list_eval_outputs = self.model.list_evaluations()
        self.eval_data = []
        for dataset in datasets:
            dataset.app_id = dataset.app_id or self.model.auth_helper.app_id
            dataset.user_id = dataset.user_id or self.model.auth_helper.user_id
            dataset_assert_msg = dataset.dataset_info
            # checking if dataset exists
            out = dataset.list_versions()
            try:
                next(iter(out))
            except Exception as e:
                if any(["CONN_DOES_NOT_EXIST" in _e for _e in e.args]):
                    raise Exception(
                        f"Dataset {dataset_assert_msg} does not exists. Please check datasets args"
                    )
                else:
                    # caused by sdk failure
                    pass
            # checking if model is evaluated with this dataset
            _is_found = False
            for each in list_eval_outputs:
                if each.status.code == status_code_pb2.MODEL_EVALUATED:
                    eval_dataset = each.ground_truth_dataset
                    # if version_id is empty -> get latest eval result of dataset,app,user id
                    if (
                        dataset.app_id == eval_dataset.app_id
                        and dataset.id == eval_dataset.id
                        and dataset.user_id == eval_dataset.user_id
                        and (
                            not dataset.version.id or dataset.version.id == eval_dataset.version.id
                        )
                    ):
                        # append to eval_data
                        self.eval_data.append(each)
                        _is_found = True
                        break

            # if not evaluated, but user wants to proceed it
            if not _is_found:
                if attempt_evaluate:
                    self.eval_data.append(self.evaluate_and_wait(dataset, eval_info=eval_info))
                # otherwise raise error
                else:
                    raise Exception(
                        f"Model {self.model.model_info.name} in app {self.model.model_info.app_id} is not evaluated yet with dataset {dataset_assert_msg}"
                    )

    @staticmethod
    def proto_to_dict(value):
        return MessageToDict(value, preserving_proto_field_name=True)

    @staticmethod
    def _f1(x: float, y: float):
        z = x + y
        return 2 * x * y / z if z else 0.0

    def _get_eval(self, index=0, **kwargs):
        logger.info(
            f"Model {self.get_model_name(pretify=True)}: retrieving {kwargs} metrics of dataset: {self.get_dataset_name_by_index(index)}"
        )
        result = self.model.get_eval_by_id(eval_id=self.eval_data[index].id, **kwargs)
        for k, v in kwargs.items():
            if v:
                getattr(self.eval_data[index], k).MergeFrom(getattr(result, k))

    def get_eval_data(self, metric_name: str, index=0):
        if metric_name == 'binary_metrics':
            if len(self.eval_data[index].binary_metrics) == 0:
                self._get_eval(index, binary_metrics=True)
        elif metric_name == 'label_counts':
            if self.proto_to_dict(self.eval_data[index].label_counts) == {}:
                self._get_eval(index, label_counts=True)
        elif metric_name == 'confusion_matrix':
            if self.eval_data[index].confusion_matrix.ByteSize() == 0:
                self._get_eval(index, confusion_matrix=True)
        elif metric_name == 'metrics_by_class':
            if len(self.eval_data[index].metrics_by_class) == 0:
                self._get_eval(index, metrics_by_class=True)
        elif metric_name == 'metrics_by_area':
            if len(self.eval_data[index].metrics_by_area) == 0:
                self._get_eval(index, metrics_by_area=True)

        return getattr(self.eval_data[index], metric_name)

    def get_threshold_index(self, threshold_list: list, selected_value: float = 0.5) -> int:
        assert 0 <= selected_value <= 1 and isinstance(selected_value, float)
        threshold_list = [round(each, 2) for each in threshold_list]

        def parse_precision(x):
            return len(str(x).split(".")[1])

        precision = parse_precision(selected_value)
        if precision > 2:
            selected_value = round(selected_value, 2)
            logger.warning("Round the selected value to .2 decimals")
        return threshold_list.index(selected_value)

    def get_dataset_name_by_index(self, index=0, pretify=True):
        out = self.eval_data[index].ground_truth_dataset
        if pretify:
            app_id = out.app_id
            dataset = out.id
            # out = f"{app_id}/{dataset}/{ver[:5]}" if ver else f"{app_id}/{dataset}"
            if self.model.model_info.app_id == app_id:
                out = dataset
            else:
                out = f"{app_id}/{dataset}"

        return out

    def get_model_name(self, pretify=True):
        model = self.model.model_info
        if pretify:
            app_id = model.app_id
            name = model.id
            ver = model.model_version.id
            model = f"{app_id}/{name}/{ver[:5]}" if ver else f"{app_id}/{name}"

        return model

    def _process_curve(
        self, data: resources_pb2.BinaryMetrics, metric_name: str, x: str, y: str
    ) -> Dict[str, Dict[str, np.array]]:
        """Postprocess curve"""
        x_arr = []
        y_arr = []
        threshold = []
        outputs = []

        def _make_df(xcol, ycol, concept_col, th_col):
            return pd.DataFrame({x: xcol, y: ycol, 'concept': concept_col, 'threshold': th_col})

        for bd in data:
            concept_id = bd.concept.id
            metric = eval(f'bd.{metric_name}')
            if metric.ByteSize() == 0:
                continue
            _x = np.array(eval(f'metric.{x}'))
            _y = np.array(eval(f'metric.{y}'))
            threshold = np.array(metric.thresholds)
            x_arr.append(_x)
            y_arr.append(_y)
            concept_cols = [concept_id for _ in range(len(_x))]
            outputs.append(_make_df(_x, _y, concept_cols, threshold))

        avg_x = np.mean(x_arr, axis=0)
        avg_y = np.mean(y_arr, axis=0)
        if np.isnan(avg_x).all():
            return None
        else:
            avg_cols = [MACRO_AVG for _ in range(len(avg_x))]
            outputs.append(_make_df(avg_x, avg_y, avg_cols, threshold))

        return pd.concat(outputs, axis=0)

    def parse_concept_ids(self, *args, **kwargs) -> List[str]:
        raise NotImplementedError

    def detailed_summary(self, *args, **kwargs):
        raise NotImplementedError

    def pr_curve(self, *args, **kwargs):
        raise NotImplementedError

    def roc_curve(self, *args, **kwargs):
        raise NotImplementedError

    def confusion_matrix(self, *args, **kwargs):
        raise NotImplementedError
```
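Two of these base-class helpers carry the numeric conventions used by every handler below: `_f1` is the usual harmonic mean 2PR/(P+R) with a zero-denominator guard, and `get_threshold_index` locates a confidence value in a curve after rounding both sides to two decimals. A quick sanity check of the arithmetic (calling the private helpers directly, purely for illustration):

```python
from clarifai.utils.evaluation.helpers import _BaseEvalResultHandler

# F1 = 2PR / (P + R); with recall 0.8 and precision 0.6:
assert abs(_BaseEvalResultHandler._f1(0.8, 0.6) - 0.96 / 1.4) < 1e-9  # ~0.686
assert _BaseEvalResultHandler._f1(0.0, 0.0) == 0.0  # guard avoids division by zero

# Thresholds are compared at two-decimal precision, so 0.5 matches a curve
# sampled at [0.0, 0.05, ..., 0.95] at index 10.
handler = _BaseEvalResultHandler(model=None)  # model is unused by this helper
assert handler.get_threshold_index([i * 0.05 for i in range(20)], 0.5) == 10
```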
The no-op fallback handler:

```python
@dataclass
class PlaceholderHandler(_BaseEvalResultHandler):
    def parse_concept_ids(self, *args, **kwargs) -> List[str]:
        return None

    def detailed_summary(self, *args, **kwargs):
        return None

    def pr_curve(self, *args, **kwargs):
        return None
```
The classification handler:

```python
@dataclass
class ClassificationResultHandler(_BaseEvalResultHandler):
    def parse_concept_ids(self, index=0) -> List[str]:
        eval_data = self.get_eval_data(metric_name='label_counts', index=index)
        concept_ids = [temp.concept.id for temp in eval_data.positive_label_counts]
        return concept_ids

    def detailed_summary(
        self, index=0, confidence_threshold: float = 0.5, **kwargs
    ) -> Union[None, Tuple[pd.DataFrame, pd.DataFrame]]:
        """Making detailed table per concept and for total concept

        Args:
            index (int, optional): Index of eval dataset. Defaults to 0.
            confidence_threshold (float, optional): confidence threshold. Defaults to 0.5.

        Returns:
            tuple: concepts dataframe, total dataframe
        """
        eval_data = self.get_eval_data('binary_metrics', index=index)
        summary = self.get_eval_data('summary', index=index)

        total_labeled = 0
        total_predicted = 0
        total_tp = 0
        total_fn = 0
        total_fp = 0
        metrics = []

        for bd in eval_data:
            concept_id = bd.concept.id
            if bd.precision_recall_curve.ByteSize() == 0:
                continue
            pr_th_index = self.get_threshold_index(
                list(bd.precision_recall_curve.thresholds), selected_value=confidence_threshold
            )
            roc_th_index = self.get_threshold_index(
                list(bd.roc_curve.thresholds), selected_value=confidence_threshold
            )
            if pr_th_index is None or roc_th_index is None:
                continue
            num_pos_labeled = bd.num_pos
            num_neg_labeled = bd.num_neg
            # TP/(TP+FP)
            precision = bd.precision_recall_curve.precision[pr_th_index]
            # TP/(TP+FN)
            recall = bd.precision_recall_curve.recall[pr_th_index]
            # FP/(FP+TN)
            fpr = bd.roc_curve.fpr[roc_th_index]
            # TP/(TP+FN)
            tpr = bd.roc_curve.tpr[roc_th_index]
            # TP+FN
            tp = int(tpr * num_pos_labeled)
            fn = num_pos_labeled - tp
            fp = int(fpr * num_neg_labeled)
            num_pos_pred = tp + fp
            f1 = self._f1(recall, precision)

            total_labeled += num_pos_labeled
            total_predicted += num_pos_pred
            total_fn += fn
            total_tp += tp
            total_fp += fp
            # roc auc, total labelled, predicted, tp, fn, fp, recall, precision, f1
            _d = OrderedDict(
                {
                    "Concept": concept_id,
                    "Accuracy (ROC AUC)": round(bd.roc_auc, 3),
                    "Total Labeled": num_pos_labeled,
                    "Total Predicted": num_pos_pred,
                    "True Positives": tp,
                    "False Negatives": fn,
                    "False Positives": fp,
                    "Recall": recall,
                    "Precision": precision,
                    "F1": f1,
                }
            )
            metrics.append(pd.DataFrame(_d, index=[0]))

        # If no valid data is found, return None
        if not metrics:
            return None
        # Make per concept df
        df = pd.concat(metrics, axis=0)
        # Make total df
        sum_df_total = sum(df["Total Labeled"])
        precision = sum(df.Precision * df["Total Labeled"]) / sum_df_total if sum_df_total else 0.0
        recall = sum(df.Recall * df["Total Labeled"]) / sum_df_total if sum_df_total else 0.0
        f1 = self._f1(recall, precision)
        df_total = pd.DataFrame(
            [
                [
                    'Total',
                    summary.macro_avg_roc_auc,
                    total_labeled,
                    total_predicted,
                    total_tp,
                    total_fn,
                    total_fp,
                    recall,
                    precision,
                    f1,
                ],
            ],
            columns=df.columns,
            index=[0],
        )

        return df, df_total

    def pr_curve(self, index=0, **kwargs) -> Union[None, pd.DataFrame]:
        """Making PR curve

        Args:
            index (int, optional): Index of eval dataset. Defaults to 0.

        Returns:
            dictionary: Keys are concept ids and 'macro_avg'. Values are dictionaries of {precision: np.array, recall: np.array}
        """
        eval_data = self.get_eval_data(metric_name='binary_metrics', index=index)
        outputs = self._process_curve(
            eval_data, metric_name='precision_recall_curve', x='recall', y='precision'
        )
        return outputs

    def roc_curve(self, index=0, **kwargs) -> Union[None, pd.DataFrame]:
        eval_data = self.get_eval_data(metric_name='binary_metrics', index=index)
        outputs = self._process_curve(eval_data, metric_name='roc_curve', x='tpr', y='fpr')
        return outputs

    def confusion_matrix(self, index=0, **kwargs):
        eval_data = self.get_eval_data(metric_name='confusion_matrix', index=index)
        concept_ids = self.parse_concept_ids(index)
        concept_ids.sort()
        data = np.zeros((len(concept_ids), len(concept_ids)), np.float32)
        for entry in eval_data.matrix:
            p = entry.predicted_concept.id
            a = entry.actual_concept.id
            if p in concept_ids and a in concept_ids:
                data[concept_ids.index(a), concept_ids.index(p)] = np.around(
                    entry.value, decimals=3
                )
            else:
                continue
        rownames = pd.MultiIndex.from_arrays([concept_ids], names=['Actual'])
        colnames = pd.MultiIndex.from_arrays([concept_ids], names=['Predicted'])
        df = pd.DataFrame(data, columns=colnames, index=rownames)

        return df
```
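Note how the classification summary rebuilds integer counts from the curves rather than reading them directly: at the selected threshold it takes `tpr` and `fpr` off the ROC curve, then derives TP = tpr x positives, FN = positives - TP, and FP = fpr x negatives. With illustrative numbers (not from any real evaluation):

```python
num_pos_labeled, num_neg_labeled = 100, 400  # ground-truth label counts
tpr, fpr = 0.9, 0.05                         # rates read off the ROC curve

tp = int(tpr * num_pos_labeled)    # 90 true positives
fn = num_pos_labeled - tp          # 10 false negatives
fp = int(fpr * num_neg_labeled)    # 20 false positives
num_pos_pred = tp + fp             # 110 predicted positives
```

The `Total` row then re-derives precision and recall as label-count-weighted averages over the per-concept rows.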
The detection handler:

```python
@dataclass
class DetectionResultHandler(_BaseEvalResultHandler):
    AREA_LIST = ["all", "medium", "small"]
    IOU_LIST = list(np.arange(0.5, 1.0, 0.1))

    def parse_concept_ids(self, index=0) -> List[str]:
        eval_data = self.get_eval_data(metric_name='metrics_by_class', index=index)
        concept_ids = [temp.concept.id for temp in eval_data]
        return concept_ids

    def detailed_summary(
        self,
        index=0,
        confidence_threshold: float = 0.5,
        iou_threshold: float = 0.5,
        area: str = "all",
        bypass_const: bool = False,
        **kwargs,
    ):
        if not bypass_const:
            assert iou_threshold in self.IOU_LIST, (
                f"Expected iou_threshold in {self.IOU_LIST}, got {iou_threshold}"
            )
            assert area in self.AREA_LIST, f"Expected area in {self.AREA_LIST}, got {area}"

        eval_data = self.get_eval_data('metrics_by_class', index=index)
        # summary = self.get_eval_data('summary', index=index)
        metrics = []
        for bd in eval_data:
            # total label
            _iou = round(bd.iou, 1)
            if not (area and bd.area_name == area) or not (
                iou_threshold and iou_threshold == _iou
            ):
                continue
            concept_id = bd.concept.id
            total = round(bd.num_tot, 3)
            # TP / (TP + FP)
            if len(bd.precision_recall_curve.precision) > 0:
                pr_th_index = self.get_threshold_index(
                    list(bd.precision_recall_curve.thresholds), selected_value=confidence_threshold
                )
                p = round(bd.precision_recall_curve.precision[pr_th_index], 3)
            else:
                p = 0
            # TP / (TP + FN)
            if len(bd.precision_recall_curve.recall) > 0:
                pr_th_index = self.get_threshold_index(
                    list(bd.precision_recall_curve.thresholds), selected_value=confidence_threshold
                )
                r = round(bd.precision_recall_curve.recall[pr_th_index], 3)
            else:
                r = 0
            tp = int(round(r * total, 0))
            fn = total - tp
            fp = float(tp) / p - tp if p else 0
            fp = int(round(fp, 1))
            f1 = self._f1(r, p)
            _d = {
                "Concept": concept_id,
                "Average Precision": round(float(bd.avg_precision), 3),
                "Total Labeled": total,
                "True Positives": tp,
                "False Positives": fp,
                "False Negatives": fn,
                "Recall": r,
                "Precision": p,
                "F1": f1,
            }
            metrics.append(pd.DataFrame(_d, index=[0]))

        if not metrics:
            return None

        df = pd.concat(metrics, axis=0)
        df_total = defaultdict()
        sum_df_total = df["Total Labeled"].sum()
        df_total["Concept"] = "Total"
        df_total["Average Precision"] = df["Average Precision"].mean()
        df_total["Total Labeled"] = sum_df_total
        df_total["True Positives"] = df["True Positives"].sum()
        df_total["False Positives"] = df["False Positives"].sum()
        df_total["False Negatives"] = df["False Negatives"].sum()
        df_total["Recall"] = (
            sum(df.Recall * df["Total Labeled"]) / sum_df_total if sum_df_total else 0.0
        )
        df_total["Precision"] = (
            df_total["True Positives"] / (df_total["True Positives"] + df_total["False Positives"])
            if sum_df_total
            else 0.0
        )
        df_total["F1"] = self._f1(df_total["Recall"], df_total["Precision"])
        df_total = pd.DataFrame(df_total, index=[0])

        return [df, df_total]

    def pr_curve(
        self, index=0, iou_threshold: float = 0.5, area: str = "all", bypass_const=False, **kwargs
    ):
        if not bypass_const:
            assert iou_threshold in self.IOU_LIST, (
                f"Expected iou_threshold in {self.IOU_LIST}, got {iou_threshold}"
            )
            assert area in self.AREA_LIST, f"Expected area in {self.AREA_LIST}, got {area}"

        eval_data = self.get_eval_data(metric_name='metrics_by_class', index=index)
        _valid_eval_data = []
        for bd in eval_data:
            _iou = round(bd.iou, 1)
            if not (area and bd.area_name == area) or not (
                iou_threshold and iou_threshold == _iou
            ):
                continue
            _valid_eval_data.append(bd)

        outputs = self._process_curve(
            _valid_eval_data, metric_name='precision_recall_curve', x='recall', y='precision'
        )
        return outputs

    def roc_curve(self, index=0, **kwargs) -> None:
        return None

    def confusion_matrix(self, index=0, **kwargs) -> None:
        return None
```
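The detection handler has no ROC data to lean on, so it inverts the precision definition instead: from precision = TP/(TP+FP) it recovers FP = TP/precision - TP. Tracing that with made-up numbers:

```python
total, r, p = 50, 0.6, 0.75          # labeled boxes, recall and precision at the threshold

tp = int(round(r * total, 0))        # 30 true positives (recall * total)
fn = total - tp                      # 20 false negatives
fp = float(tp) / p - tp if p else 0  # 30 / 0.75 - 30 = 10 false positives
fp = int(round(fp, 1))
```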
And the factory that selects a handler class from the model type:

```python
def make_handler_by_type(model_type: str) -> _BaseEvalResultHandler:
    _eval_type = get_eval_type(model_type)
    if _eval_type == EvalType.CLASSIFICATION:
        return ClassificationResultHandler
    elif _eval_type == EvalType.DETECTION:
        return DetectionResultHandler
    else:
        return PlaceholderHandler
```
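Pulling the pieces together, the typical flow is: map the model type to a handler class, bind it to a `Model`, resolve (or trigger) an evaluation against one or more datasets, then read the result tables. A minimal sketch under assumptions the diff itself does not pin down: `CLARIFAI_PAT` is set in the environment, the `Model`/`Dataset` constructor arguments shown are illustrative placeholders, and the model type is taken from the model's proto (`model_info.model_type_id`):

```python
from clarifai.client.dataset import Dataset
from clarifai.client.model import Model
from clarifai.utils.evaluation.helpers import make_handler_by_type

model = Model(url="https://clarifai.com/<user>/<app>/models/<model-id>")     # placeholder URL
dataset = Dataset(dataset_id="<eval-dataset>", app_id="<app>", user_id="<user>")

handler_cls = make_handler_by_type(model.model_info.model_type_id)  # e.g. ClassificationResultHandler
handler = handler_cls(model=model)

# Reuse an existing evaluation if one matches, otherwise run one and wait:
handler.find_eval_id(datasets=[dataset], attempt_evaluate=True)

per_concept, total = handler.detailed_summary(confidence_threshold=0.5)
pr = handler.pr_curve()  # long-form DataFrame with per-concept and "macro_avg" rows
```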