clarifai 10.2.1__py3-none-any.whl → 10.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/client/app.py +15 -7
- clarifai/client/auth/helper.py +12 -2
- clarifai/client/base.py +14 -4
- clarifai/client/dataset.py +15 -2
- clarifai/client/input.py +14 -1
- clarifai/client/model.py +198 -19
- clarifai/client/module.py +9 -1
- clarifai/client/search.py +10 -2
- clarifai/client/user.py +22 -14
- clarifai/client/workflow.py +9 -1
- clarifai/constants/input.py +1 -0
- clarifai/utils/evaluation/__init__.py +2 -426
- clarifai/utils/evaluation/main.py +426 -0
- clarifai/utils/evaluation/testset_annotation_parser.py +150 -0
- clarifai/versions.py +1 -1
- {clarifai-10.2.1.dist-info → clarifai-10.3.0.dist-info}/METADATA +12 -12
- {clarifai-10.2.1.dist-info → clarifai-10.3.0.dist-info}/RECORD +21 -22
- {clarifai-10.2.1.dist-info → clarifai-10.3.0.dist-info}/WHEEL +1 -1
- clarifai/client/runner.py +0 -234
- clarifai/runners/__init__.py +0 -0
- clarifai/runners/example.py +0 -40
- clarifai/runners/example_llama2.py +0 -81
- {clarifai-10.2.1.dist-info → clarifai-10.3.0.dist-info}/LICENSE +0 -0
- {clarifai-10.2.1.dist-info → clarifai-10.3.0.dist-info}/entry_points.txt +0 -0
- {clarifai-10.2.1.dist-info → clarifai-10.3.0.dist-info}/top_level.txt +0 -0
clarifai/client/user.py
CHANGED
@@ -7,7 +7,6 @@ from google.protobuf.json_format import MessageToDict
 from clarifai.client.app import App
 from clarifai.client.base import BaseClient
 from clarifai.client.lister import Lister
-from clarifai.client.runner import Runner
 from clarifai.errors import UserError
 from clarifai.utils.logging import get_logger
 
@@ -20,6 +19,7 @@ class User(Lister, BaseClient):
                base_url: str = "https://api.clarifai.com",
                pat: str = None,
                token: str = None,
+               root_certificates_path: str = None,
                **kwargs):
     """Initializes an User object.
 
@@ -28,12 +28,20 @@ class User(Lister, BaseClient):
         base_url (str): Base API url. Default "https://api.clarifai.com"
         pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
         token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
+        root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
         **kwargs: Additional keyword arguments to be passed to the User.
     """
     self.kwargs = {**kwargs, 'id': user_id}
     self.user_info = resources_pb2.User(**self.kwargs)
     self.logger = get_logger(logger_level="INFO", name=__name__)
-    BaseClient.__init__(self, user_id=self.id, app_id="", base=base_url, pat=pat, token=token)
+    BaseClient.__init__(
+        self,
+        user_id=self.id,
+        app_id="",
+        base=base_url,
+        pat=pat,
+        token=token,
+        root_certificates_path=root_certificates_path)
     Lister.__init__(self)
 
   def list_apps(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
@@ -69,7 +77,7 @@ class User(Lister, BaseClient):
           **app_info)  #(base_url=self.base, pat=self.pat, token=self.token, **app_info)
 
   def list_runners(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
-                   per_page: int = None) -> Generator[Runner, None, None]:
+                   per_page: int = None) -> Generator[dict, None, None]:
     """List all runners for the user
 
     Args:
@@ -78,7 +86,7 @@ class User(Lister, BaseClient):
         per_page (int): The number of items per page.
 
     Yields:
-        Runner: Runner objects for the runners.
+        Dict: Dictionaries containing information about the runners.
 
     Example:
         >>> from clarifai.client.user import User
@@ -98,8 +106,7 @@ class User(Lister, BaseClient):
         page_no=page_no)
 
     for runner_info in all_runners_info:
-      yield Runner.from_auth_helper(
-          auth=self.auth_helper, check_runner_exists=False, **runner_info)
+      yield dict(auth=self.auth_helper, check_runner_exists=False, **runner_info)
 
   def create_app(self, app_id: str, base_workflow: str = 'Empty', **kwargs) -> App:
     """Creates an app for the user.
@@ -127,7 +134,7 @@ class User(Lister, BaseClient):
     self.logger.info("\nApp created\n%s", response.status)
     return App.from_auth_helper(auth=self.auth_helper, app_id=app_id)
 
-  def create_runner(self, runner_id: str, labels: List[str], description: str) -> Runner:
+  def create_runner(self, runner_id: str, labels: List[str], description: str) -> dict:
     """Create a runner
 
     Args:
@@ -136,13 +143,14 @@ class User(Lister, BaseClient):
         description (str): Description of Runner
 
     Returns:
-        Runner: A Runner object for the specified Runner ID.
+        Dict: A dictionary containing information about the specified Runner ID.
 
     Example:
         >>> from clarifai.client.user import User
        >>> client = User(user_id="user_id")
-        >>> runner = client.create_runner(runner_id="runner_id", labels=["label to link runner"], description="laptop runner")
+        >>> runner_info = client.create_runner(runner_id="runner_id", labels=["label to link runner"], description="laptop runner")
     """
+
     if not isinstance(labels, List):
       raise UserError("Labels must be a List of strings")
 
@@ -155,7 +163,7 @@ class User(Lister, BaseClient):
       raise Exception(response.status)
     self.logger.info("\nRunner created\n%s", response.status)
 
-    return Runner.from_auth_helper(
+    return dict(
         auth=self.auth_helper,
         runner_id=runner_id,
         user_id=self.id,
@@ -186,19 +194,19 @@ class User(Lister, BaseClient):
     kwargs['user_id'] = self.id
     return App.from_auth_helper(auth=self.auth_helper, app_id=app_id, **kwargs)
 
-  def runner(self, runner_id: str) -> Runner:
+  def runner(self, runner_id: str) -> dict:
     """Returns a Runner object if exists.
 
     Args:
         runner_id (str): The runner ID to interact with
 
     Returns:
-        Runner: A Runner object for the existing runner ID.
+        Dict: A dictionary containing information about the existing runner ID.
 
     Example:
         >>> from clarifai.client.user import User
         >>> client = User(user_id="user_id")
-        >>> runner = client.runner(runner_id="runner_id")
+        >>> runner_info = client.runner(runner_id="runner_id")
     """
     request = service_pb2.GetRunnerRequest(user_app_id=self.user_app_id, runner_id=runner_id)
     response = self._grpc_request(self.STUB.GetRunner, request)
@@ -212,7 +220,7 @@ class User(Lister, BaseClient):
     kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
                                         list(dict_response.keys())[1])
 
-    return Runner.from_auth_helper(self.auth_helper, check_runner_exists=False, **kwargs)
+    return dict(self.auth_helper, check_runner_exists=False, **kwargs)
 
   def delete_app(self, app_id: str) -> None:
     """Deletes an app for the user.
clarifai/client/workflow.py
CHANGED
@@ -28,6 +28,7 @@ class Workflow(Lister, BaseClient):
                base_url: str = "https://api.clarifai.com",
                pat: str = None,
                token: str = None,
+               root_certificates_path: str = None,
                **kwargs):
     """Initializes a Workflow object.
 
@@ -43,6 +44,7 @@ class Workflow(Lister, BaseClient):
         base_url (str): Base API url. Default "https://api.clarifai.com"
         pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
         token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
+        root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
         **kwargs: Additional keyword arguments to be passed to the Workflow.
     """
     if url and workflow_id:
@@ -59,7 +61,13 @@ class Workflow(Lister, BaseClient):
     self.workflow_info = resources_pb2.Workflow(**self.kwargs)
     self.logger = get_logger(logger_level="INFO", name=__name__)
     BaseClient.__init__(
-        self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
+        self,
+        user_id=self.user_id,
+        app_id=self.app_id,
+        base=base_url,
+        pat=pat,
+        token=token,
+        root_certificates_path=root_certificates_path)
     Lister.__init__(self)
 
   def predict(self, inputs: List[Input], workflow_state_id: str = None):
clarifai/constants/input.py
CHANGED
@@ -0,0 +1 @@
+MAX_UPLOAD_BATCH_SIZE = 128
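This new constant caps how many inputs go into a single upload batch. A small sketch of the chunking such a cap implies; the inputs list is a stand-in for real input protos, not a Clarifai API object:

from clarifai.constants.input import MAX_UPLOAD_BATCH_SIZE

inputs = [f"input-{i}" for i in range(300)]  # placeholder items
batches = [
    inputs[i:i + MAX_UPLOAD_BATCH_SIZE]
    for i in range(0, len(inputs), MAX_UPLOAD_BATCH_SIZE)
]
assert all(len(batch) <= MAX_UPLOAD_BATCH_SIZE for batch in batches)  # at most 128 each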
clarifai/utils/evaluation/__init__.py
CHANGED
@@ -1,427 +1,3 @@
-import os
-from enum import Enum
-from typing import List, Tuple, Union
+from .main import EvalResultCompare
 
-from clarifai.client.dataset import Dataset
-from clarifai.client.model import Model
-
-from .helpers import (MACRO_AVG, EvalType, _BaseEvalResultHandler, get_eval_type,
-                      make_handler_by_type)
-
-try:
-  import seaborn as sns
-except ImportError:
-  raise ImportError("Can not import seaborn. Please run `pip install seaborn` to install it")
-
-try:
-  import matplotlib.pyplot as plt
-except ImportError:
-  raise ImportError("Can not import matplotlib. Please run `pip install matplotlib` to install it")
-
-try:
-  import pandas as pd
-except ImportError:
-  raise ImportError("Can not import pandas. Please run `pip install pandas` to install it")
-
-try:
-  from loguru import logger
-except ImportError:
-  from ..logging import get_logger
-  logger = get_logger(logger_level="INFO", name=__name__)
-
-__all__ = ['EvalResultCompare']
-
-
-class CompareMode(Enum):
-  MANY_MODELS_TO_ONE_DATA = 0
-  ONE_MODEL_TO_MANY_DATA = 1
-
-
-class EvalResultCompare:
-  """Compare evaluation result of models against datasets.
-  Note: The module will pick latest result on the datasets.
-  and models must be same model type
-
-  Args:
-  ---
-    models (Union[List[Model], List[str]]): List of Model or urls of models.
-    datasets (Union[Dataset, List[Dataset], str, List[str]]): A single or List of Url or Dataset
-    attempt_evaluate (bool): Evaluate when model is not evaluated with the datasets.
-    auth_kwargs (dict): Additional auth keyword arguments to be passed to the Dataset and Model if using url(s)
-  """
-
-  def __init__(self,
-               models: Union[List[Model], List[str]],
-               datasets: Union[Dataset, List[Dataset], str, List[str]],
-               attempt_evaluate: bool = False,
-               auth_kwargs: dict = {}):
-    assert isinstance(models, list), ValueError("Expected list")
-
-    if len(models) > 1:
-      self.mode = CompareMode.MANY_MODELS_TO_ONE_DATA
-      self.comparator = "Model"
-      assert isinstance(datasets, Dataset) or (
-          isinstance(datasets, list) and len(datasets) == 1
-      ), f"When comparing multiple models, must provide only one `datasets`. However got {datasets}"
-    else:
-      self.mode = CompareMode.ONE_MODEL_TO_MANY_DATA
-      self.comparator = "Dataset"
-
-    # validate models
-    if all(map(lambda x: isinstance(x, str), models)):
-      models = [Model(each, **auth_kwargs) for each in models]
-    elif not all(map(lambda x: isinstance(x, Model), models)):
-      raise ValueError(
-          f"Expected all models are list of string or list of Model, got {[type(each) for each in models]}"
-      )
-    # validate datasets
-    if not isinstance(datasets, list):
-      datasets = [
-          datasets,
-      ]
-    if all(map(lambda x: isinstance(x, str), datasets)):
-      datasets = [Dataset(each, **auth_kwargs) for each in datasets]
-    elif not all(map(lambda x: isinstance(x, Dataset), datasets)):
-      raise ValueError(
-          f"Expected datasets must be str, list of string or Dataset, list of Dataset, got {[type(each) for each in datasets]}"
-      )
-    # Validate models vs datasets together
-    self._eval_handlers: List[_BaseEvalResultHandler] = []
-    self.model_type = None
-    logger.info("Initializing models...")
-    for model in models:
-      model.load_info()
-      model_type = model.model_info.model_type_id
-      if not self.model_type:
-        self.model_type = model_type
-      else:
-        assert self.model_type == model_type, f"Can not compare when model types are different, {self.model_type} != {model_type}"
-      m = make_handler_by_type(model_type)(model=model)
-      logger.info(f"* {m.get_model_name(pretify=True)}")
-      m.find_eval_id(datasets=datasets, attempt_evaluate=attempt_evaluate)
-      self._eval_handlers.append(m)
-
-  @property
-  def eval_handlers(self):
-    return self._eval_handlers
-
-  def _loop_eval_handlers(self, func_name: str, **kwargs) -> Tuple[list, list]:
-    """ Run methods of `eval_handlers[...].model`
-
-    Args:
-      func_name (str): method name, see `_BaseEvalResultHandler` child classes
-      kwargs: keyword arguments of the method
-
-    Return:
-      tuple:
-        - list of outputs
-        - list of comparator names
-
-    """
-    outs = []
-    comparators = []
-    logger.info(f'Running `{func_name}`')
-    for _, each in enumerate(self.eval_handlers):
-      for ds_index, _ in enumerate(each.eval_data):
-        func = eval(f'each.{func_name}')
-        out = func(index=ds_index, **kwargs)
-
-        if self.mode == CompareMode.MANY_MODELS_TO_ONE_DATA:
-          name = each.get_model_name(pretify=True)
-        else:
-          name = each.get_dataset_name_by_index(ds_index, pretify=True)
-        if out is None:
-          logger.warning(f'{self.comparator}:{name} does not have valid data for `{func_name}`')
-          continue
-        comparators.append(name)
-        outs.append(out)
-
-    # remove app_id if models are in the same app
-    if self.mode == CompareMode.MANY_MODELS_TO_ONE_DATA:
-      apps = set([comp.split('/')[0] for comp in comparators])
-      if len(apps) == 1:
-        comparators = ['/'.join(comp.split('/')[1:]) for comp in comparators]
-
-    if not outs:
-      logger.warning(f'Model type {self.model_type} does not support `{func_name}`')
-
-    return outs, comparators
-
-  def detailed_summary(self,
-                       confidence_threshold: float = .5,
-                       iou_threshold: float = .5,
-                       area: str = "all",
-                       bypass_const=False) -> Union[Tuple[pd.DataFrame, pd.DataFrame], None]:
-    """
-    Retrieve and compute popular metrics of model.
-
-    Args:
-      confidence_threshold (float): confidence threshold, applicable for classification and detection. Default is 0.5
-      iou_threshold (float): iou threshold, support in range(0.5, 1., step=0.1) applicable for detection
-      area (float): size of area, support {all, small, medium}, applicable for detection
-
-    Return:
-      None or tuple of dataframe: df summary per concept and total concepts
-
-    """
-    df = []
-    total = []
-    # loop over all eval_handlers/dataset and call its method
-    outs, comparators = self._loop_eval_handlers(
-        'detailed_summary',
-        confidence_threshold=confidence_threshold,
-        iou_threshold=iou_threshold,
-        area=area,
-        bypass_const=bypass_const)
-    for indx, out in enumerate(outs):
-      _df, _total = out
-      _df[self.comparator] = [comparators[indx] for _ in range(len(_df))]
-      _total['Concept'].replace(
-          to_replace=['Total'], value=f'{self.comparator}:{comparators[indx]}', inplace=True)
-      _total.rename({'Concept': 'Total Concept'}, axis=1, inplace=True)
-      df.append(_df)
-      total.append(_total)
-
-    if df:
-      df = pd.concat(df, axis=0)
-      total = pd.concat(total, axis=0)
-      return df, total
-    else:
-      return None
-
-  def confusion_matrix(self, show=True, save_path: str = None,
-                       cm_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
-    """Return dataframe of confusion matrix
-    Args:
-      show (bool, optional): Show the chart. Defaults to True.
-      save_path (str): path to save rendered chart.
-      cm_kwargs (dict): keyword args of `eval_handler[...].model.cm_kwargs` method.
-    Returns:
-      None or pd.Dataframe, If models don't have confusion matrix, return None
-    """
-    outs, comparators = self._loop_eval_handlers("confusion_matrix", **cm_kwargs)
-    all_dfs = []
-    for _, (df, anchor) in enumerate(zip(outs, comparators)):
-      df[self.comparator] = [anchor for _ in range(len(df))]
-      all_dfs.append(df)
-
-    if all_dfs:
-      all_dfs = pd.concat(all_dfs, axis=0)
-      if save_path or show:
-
-        def _facet_heatmap(data, **kws):
-          data = data.dropna(axis=1)
-          data = data.drop(self.comparator, axis=1)
-          concepts = data.columns
-          colnames = pd.MultiIndex.from_arrays([concepts], names=['Predicted'])
-          data.columns = colnames
-          ax = sns.heatmap(data, cmap='Blues', annot=True, annot_kws={"fontsize": 8}, **kws)
-          ax.set_xticklabels(ax.get_xticklabels(), rotation=45, fontsize=6)
-          ax.set_yticklabels(ax.get_yticklabels(), fontsize=6, rotation=0)
-
-        temp = all_dfs.copy()
-        temp.columns = ["_".join(pair) for pair in temp.columns]
-        with sns.plotting_context(font_scale=5.5):
-          g = sns.FacetGrid(
-              temp,
-              col=self.comparator,
-              col_wrap=3,
-              aspect=1,
-              height=3,
-              sharex=False,
-              sharey=False,
-          )
-          cbar_ax = g.figure.add_axes([.92, .3, .02, .4])
-          g = g.map_dataframe(
-              _facet_heatmap, cbar_ax=cbar_ax, vmin=0, vmax=1, cbar=True, square=True)
-          g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
-          if show:
-            plt.show()
-          if save_path:
-            g.savefig(save_path)
-
-    return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
-
-  @staticmethod
-  def _set_default_kwargs(kwargs: dict, var_name: str, value):
-    if var_name not in kwargs:
-      kwargs.update({var_name: value})
-    return kwargs
-
-  @staticmethod
-  def _setup_default_lineplot(df: pd.DataFrame, kwargs: dict):
-    hue_order = df["concept"].unique().tolist()
-    hue_order.remove(MACRO_AVG)
-    hue_order.insert(0, MACRO_AVG)
-    EvalResultCompare._set_default_kwargs(kwargs, "hue_order", hue_order)
-
-    sizes = {}
-    for each in hue_order:
-      s = 1.5
-      if each == MACRO_AVG:
-        s = 4.
-      sizes.update({each: s})
-    EvalResultCompare._set_default_kwargs(kwargs, "sizes", sizes)
-    EvalResultCompare._set_default_kwargs(kwargs, "size", "concept")
-
-    EvalResultCompare._set_default_kwargs(kwargs, "errorbar", None)
-    EvalResultCompare._set_default_kwargs(kwargs, "height", 5)
-
-    return kwargs
-
-  def roc_curve_plot(self,
-                     show=True,
-                     save_path: str = None,
-                     roc_curve_kwargs: dict = {},
-                     relplot_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
-    """Return dataframe of ROC curve
-    Args:
-      show (bool, optional): Show the chart. Defaults to True.
-      save_path (str): path to save rendered chart.
-      pr_curve_kwargs (dict): keyword args of `eval_handler[...].model.roc_curve` method.
-      relplot_kwargs (dict): keyword args of `sns.relplot` except {data,x,y,hue,kind,col}. where x="fpr", y="tpr", hue="concept"
-    Returns:
-      None or pd.Dataframe, If models don't have ROC curve, return None
-    """
-    sns.set_palette("Paired")
-    outs, comparator = self._loop_eval_handlers("roc_curve", **roc_curve_kwargs)
-    all_dfs = []
-    for _, (df, anchor) in enumerate(zip(outs, comparator)):
-      df[self.comparator] = [anchor for _ in range(len(df))]
-      all_dfs.append(df)
-
-    if all_dfs:
-      all_dfs = pd.concat(all_dfs, axis=0)
-      if save_path or show:
-        relplot_kwargs = self._setup_default_lineplot(all_dfs, relplot_kwargs)
-        g = sns.relplot(
-            data=all_dfs,
-            x="fpr",
-            y="tpr",
-            hue='concept',
-            kind="line",
-            col=self.comparator,
-            **relplot_kwargs)
-        g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
-        if show:
-          plt.show()
-        if save_path:
-          g.savefig(save_path)
-
-    return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
-
-  def pr_plot(self,
-              show=True,
-              save_path: str = None,
-              pr_curve_kwargs: dict = {},
-              relplot_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
-    """Return dataframe of PR curve
-    Args:
-      show (bool, optional): Show the chart. Defaults to True.
-      save_path (str): path to save rendered chart.
-      pr_curve_kwargs (dict): keyword args of `eval_handler[...].model.pr_curve` method.
-      relplot_kwargs (dict): keyword args of `sns.relplot` except {data,x,y,hue,kind,col} where x="recall", y="precision", hue="concept"
-    Returns:
-      None or pd.Dataframe, If models don't have PR curve, return None
-    """
-    sns.set_palette("Paired")
-    outs, comparator = self._loop_eval_handlers("pr_curve", **pr_curve_kwargs)
-    all_dfs = []
-    for _, (df, anchor) in enumerate(zip(outs, comparator)):
-      df[self.comparator] = [anchor for _ in range(len(df))]
-      all_dfs.append(df)
-
-    if all_dfs:
-      all_dfs = pd.concat(all_dfs, axis=0)
-      if save_path or show:
-        relplot_kwargs = self._setup_default_lineplot(all_dfs, relplot_kwargs)
-        g = sns.relplot(
-            data=all_dfs,
-            x="recall",
-            y="precision",
-            hue='concept',
-            kind="line",
-            col=self.comparator,
-            **relplot_kwargs)
-        g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
-        if show:
-          plt.show()
-        if save_path:
-          g.savefig(save_path)
-
-    return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
-
-  def all(
-      self,
-      output_folder: str,
-      confidence_threshold: float = 0.5,
-      iou_threshold: float = 0.5,
-      overwrite: bool = False,
-      metric_kwargs: dict = {},
-      pr_plot_kwargs: dict = {},
-      roc_plot_kwargs: dict = {},
-  ):
-    """Run all comparison methods one by one:
-    - detailed_summary
-    - pr_curve (if applicable)
-    - pr_plot
-    - confusion_matrix (if applicable)
-    And save to output_folder
-
-    Args:
-      output_folder (str): path to output
-      confidence_threshold (float): confidence threshold, applicable for classification and detection. Default is 0.5.
-      iou_threshold (float): iou threshold, support in range(0.5, 1., step=0.1) applicable for detection.
-      overwrite (bool): overwrite result of output_folder.
-      metric_kwargs (dict): keyword args for `eval_handler[...].model.{method}`, except for {confidence_threshold, iou_threshold}.
-      roc_plot_kwargs (dict): for relplot_kwargs of `roc_curve_plot` method.
-      pr_plot_kwargs (dict): for relplot_kwargs of `pr_plot` method.
-    """
-    eval_type = get_eval_type(self.model_type)
-    area = metric_kwargs.pop("area", "all")
-    bypass_const = metric_kwargs.pop("bypass_const", False)
-
-    fname = f"conf-{confidence_threshold}"
-    if eval_type == EvalType.DETECTION:
-      fname = f"{fname}_iou-{iou_threshold}_area-{area}"
-
-    def join_root(*args):
-      return os.path.join(output_folder, *args)
-
-    output_folder = join_root(fname)
-    if os.path.exists(output_folder) and not overwrite:
-      raise RuntimeError(f"{output_folder} exists. If you want to overwrite, set `overwrite=True`")
-
-    os.makedirs(output_folder, exist_ok=True)
-
-    logger.info("Making summary tables...")
-    dfs = self.detailed_summary(
-        confidence_threshold=confidence_threshold,
-        iou_threshold=iou_threshold,
-        area=area,
-        bypass_const=bypass_const)
-    if dfs is not None:
-      concept_df, total_df = dfs
-      concept_df.to_csv(join_root("concepts_summary.csv"))
-      total_df.to_csv(join_root("total_summary.csv"))
-
-    curve_metric_kwargs = dict(
-        confidence_threshold=confidence_threshold, iou_threshold=iou_threshold)
-    curve_metric_kwargs.update(metric_kwargs)
-
-    self.roc_curve_plot(
-        show=False,
-        save_path=join_root("roc.jpg"),
-        roc_curve_kwargs=curve_metric_kwargs,
-        relplot_kwargs=roc_plot_kwargs)
-
-    self.pr_plot(
-        show=False,
-        save_path=join_root("pr.jpg"),
-        pr_curve_kwargs=curve_metric_kwargs,
-        relplot_kwargs=pr_plot_kwargs)
-
-    self.confusion_matrix(
-        show=False, save_path=join_root("confusion_matrix.jpg"), cm_kwargs=curve_metric_kwargs)
-
-    logger.info(f"Done. Your outputs are saved at {output_folder}")
+
+__all__ = ["EvalResultCompare"]
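The package initializer shrinks to a re-export: the EvalResultCompare implementation moves wholesale to the new clarifai/utils/evaluation/main.py (+426 lines in the file list above), so the old import path keeps resolving. A quick check, assuming clarifai 10.3.0 is installed:

from clarifai.utils.evaluation import EvalResultCompare
from clarifai.utils.evaluation.main import EvalResultCompare as FromMain

assert EvalResultCompare is FromMain  # same class via either path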