clarifai-10.2.1-py3-none-any.whl → clarifai-10.3.0-py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
clarifai/client/user.py CHANGED
@@ -7,7 +7,6 @@ from google.protobuf.json_format import MessageToDict
  from clarifai.client.app import App
  from clarifai.client.base import BaseClient
  from clarifai.client.lister import Lister
- from clarifai.client.runner import Runner
  from clarifai.errors import UserError
  from clarifai.utils.logging import get_logger

@@ -20,6 +19,7 @@ class User(Lister, BaseClient):
  base_url: str = "https://api.clarifai.com",
  pat: str = None,
  token: str = None,
+ root_certificates_path: str = None,
  **kwargs):
  """Initializes an User object.

@@ -28,12 +28,20 @@ class User(Lister, BaseClient):
  base_url (str): Base API url. Default "https://api.clarifai.com"
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
  **kwargs: Additional keyword arguments to be passed to the User.
  """
  self.kwargs = {**kwargs, 'id': user_id}
  self.user_info = resources_pb2.User(**self.kwargs)
  self.logger = get_logger(logger_level="INFO", name=__name__)
- BaseClient.__init__(self, user_id=self.id, app_id="", base=base_url, pat=pat, token=token)
+ BaseClient.__init__(
+ self,
+ user_id=self.id,
+ app_id="",
+ base=base_url,
+ pat=pat,
+ token=token,
+ root_certificates_path=root_certificates_path)
  Lister.__init__(self)

  def list_apps(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
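The two hunks above add a root_certificates_path keyword to User.__init__ and forward it to BaseClient, so a custom CA bundle can be supplied for the secure gRPC connection. A minimal usage sketch (the PAT and certificate path are placeholders, not values from this diff; CLARIFAI_PAT can also come from the environment):

from clarifai.client.user import User

client = User(
    user_id="user_id",
    pat="YOUR_PAT",                               # placeholder credential
    root_certificates_path="/path/to/roots.pem")  # new in 10.3.0, forwarded to BaseClient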
@@ -69,7 +77,7 @@ class User(Lister, BaseClient):
  **app_info) #(base_url=self.base, pat=self.pat, token=self.token, **app_info)

  def list_runners(self, filter_by: Dict[str, Any] = {}, page_no: int = None,
- per_page: int = None) -> Generator[Runner, None, None]:
+ per_page: int = None) -> Generator[dict, None, None]:
  """List all runners for the user

  Args:
@@ -78,7 +86,7 @@ class User(Lister, BaseClient):
  per_page (int): The number of items per page.

  Yields:
- Runner: Runner objects for the runners.
+ Dict: Dictionaries containing information about the runners.

  Example:
  >>> from clarifai.client.user import User
@@ -98,8 +106,7 @@ class User(Lister, BaseClient):
  page_no=page_no)

  for runner_info in all_runners_info:
- yield Runner.from_auth_helper(
- auth=self.auth_helper, check_runner_exists=False, **runner_info)
+ yield dict(auth=self.auth_helper, check_runner_exists=False, **runner_info)

  def create_app(self, app_id: str, base_workflow: str = 'Empty', **kwargs) -> App:
  """Creates an app for the user.
@@ -127,7 +134,7 @@ class User(Lister, BaseClient):
  self.logger.info("\nApp created\n%s", response.status)
  return App.from_auth_helper(auth=self.auth_helper, app_id=app_id)

- def create_runner(self, runner_id: str, labels: List[str], description: str) -> Runner:
+ def create_runner(self, runner_id: str, labels: List[str], description: str) -> dict:
  """Create a runner

  Args:
@@ -136,13 +143,14 @@ class User(Lister, BaseClient):
  description (str): Description of Runner

  Returns:
- Runner: A runner object for the specified Runner ID
+ Dict: A dictionary containing information about the specified Runner ID.

  Example:
  >>> from clarifai.client.user import User
  >>> client = User(user_id="user_id")
- >>> runner = client.create_runner(runner_id="runner_id", labels=["label to link runner"], description="laptop runner")
+ >>> runner_info = client.create_runner(runner_id="runner_id", labels=["label to link runner"], description="laptop runner")
  """
+
  if not isinstance(labels, List):
  raise UserError("Labels must be a List of strings")

@@ -155,7 +163,7 @@ class User(Lister, BaseClient):
  raise Exception(response.status)
  self.logger.info("\nRunner created\n%s", response.status)

- return Runner.from_auth_helper(
+ return dict(
  auth=self.auth_helper,
  runner_id=runner_id,
  user_id=self.id,
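create_runner changes the same way: it now returns a dict containing at least the auth helper, runner_id and user_id visible in the hunk above, rather than a Runner object. A sketch based on the updated docstring example:

runner_info = client.create_runner(
    runner_id="runner_id",
    labels=["label to link runner"],
    description="laptop runner")
# code that previously called Runner methods on the return value needs updating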
@@ -186,19 +194,19 @@ class User(Lister, BaseClient):
  kwargs['user_id'] = self.id
  return App.from_auth_helper(auth=self.auth_helper, app_id=app_id, **kwargs)

- def runner(self, runner_id: str) -> Runner:
+ def runner(self, runner_id: str) -> dict:
  """Returns a Runner object if exists.

  Args:
  runner_id (str): The runner ID to interact with

  Returns:
- Runner: A Runner object for the existing runner ID.
+ Dict: A dictionary containing information about the existing runner ID.

  Example:
  >>> from clarifai.client.user import User
  >>> client = User(user_id="user_id")
- >>> runner = client.runner(runner_id="runner_id")
+ >>> runner_info = client.runner(runner_id="runner_id")
  """
  request = service_pb2.GetRunnerRequest(user_app_id=self.user_app_id, runner_id=runner_id)
  response = self._grpc_request(self.STUB.GetRunner, request)
@@ -212,7 +220,7 @@ class User(Lister, BaseClient):
  kwargs = self.process_response_keys(dict_response[list(dict_response.keys())[1]],
  list(dict_response.keys())[1])

- return Runner.from_auth_helper(self.auth_helper, check_runner_exists=False, **kwargs)
+ return dict(self.auth_helper, check_runner_exists=False, **kwargs)

  def delete_app(self, app_id: str) -> None:
  """Deletes an app for the user.
@@ -28,6 +28,7 @@ class Workflow(Lister, BaseClient):
  base_url: str = "https://api.clarifai.com",
  pat: str = None,
  token: str = None,
+ root_certificates_path: str = None,
  **kwargs):
  """Initializes a Workflow object.

@@ -43,6 +44,7 @@ class Workflow(Lister, BaseClient):
  base_url (str): Base API url. Default "https://api.clarifai.com"
  pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
  token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
+ root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
  **kwargs: Additional keyword arguments to be passed to the Workflow.
  """
  if url and workflow_id:
@@ -59,7 +61,13 @@
  self.workflow_info = resources_pb2.Workflow(**self.kwargs)
  self.logger = get_logger(logger_level="INFO", name=__name__)
  BaseClient.__init__(
- self, user_id=self.user_id, app_id=self.app_id, base=base_url, pat=pat, token=token)
+ self,
+ user_id=self.user_id,
+ app_id=self.app_id,
+ base=base_url,
+ pat=pat,
+ token=token,
+ root_certificates_path=root_certificates_path)
  Lister.__init__(self)

  def predict(self, inputs: List[Input], workflow_state_id: str = None):
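Workflow gains the same keyword and forwards it to BaseClient identically. A minimal sketch (the workflow URL, PAT and certificate path are placeholders):

from clarifai.client.workflow import Workflow

workflow = Workflow(
    url="https://clarifai.com/user_id/app_id/workflows/workflow_id",  # placeholder
    pat="YOUR_PAT",                                                   # placeholder
    root_certificates_path="/path/to/roots.pem")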
@@ -0,0 +1 @@
+ MAX_UPLOAD_BATCH_SIZE = 128
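The one-line file added above (its path is not shown in this excerpt) only defines an upload batch-size constant; the diff does not show where it is consumed. The snippet below is purely illustrative of the batching pattern such a limit implies; chunked is a hypothetical helper, not part of the SDK:

MAX_UPLOAD_BATCH_SIZE = 128  # mirrors the constant added above

def chunked(items, batch_size=MAX_UPLOAD_BATCH_SIZE):
    # yield successive slices no larger than the upload limit
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]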
@@ -1,427 +1,3 @@
- import os
- from enum import Enum
- from typing import List, Tuple, Union
+ from .main import EvalResultCompare

- from clarifai.client.dataset import Dataset
- from clarifai.client.model import Model
-
- from .helpers import (MACRO_AVG, EvalType, _BaseEvalResultHandler, get_eval_type,
- make_handler_by_type)
-
- try:
- import seaborn as sns
- except ImportError:
- raise ImportError("Can not import seaborn. Please run `pip install seaborn` to install it")
-
- try:
- import matplotlib.pyplot as plt
- except ImportError:
- raise ImportError("Can not import matplotlib. Please run `pip install matplotlib` to install it")
-
- try:
- import pandas as pd
- except ImportError:
- raise ImportError("Can not import pandas. Please run `pip install pandas` to install it")
-
- try:
- from loguru import logger
- except ImportError:
- from ..logging import get_logger
- logger = get_logger(logger_level="INFO", name=__name__)
-
- __all__ = ['EvalResultCompare']
-
-
- class CompareMode(Enum):
- MANY_MODELS_TO_ONE_DATA = 0
- ONE_MODEL_TO_MANY_DATA = 1
-
-
- class EvalResultCompare:
- """Compare evaluation result of models against datasets.
- Note: The module will pick latest result on the datasets.
- and models must be same model type
-
- Args:
- ---
- models (Union[List[Model], List[str]]): List of Model or urls of models.
- datasets (Union[Dataset, List[Dataset], str, List[str]]): A single or List of Url or Dataset
- attempt_evaluate (bool): Evaluate when model is not evaluated with the datasets.
- auth_kwargs (dict): Additional auth keyword arguments to be passed to the Dataset and Model if using url(s)
- """
-
- def __init__(self,
- models: Union[List[Model], List[str]],
- datasets: Union[Dataset, List[Dataset], str, List[str]],
- attempt_evaluate: bool = False,
- auth_kwargs: dict = {}):
- assert isinstance(models, list), ValueError("Expected list")
-
- if len(models) > 1:
- self.mode = CompareMode.MANY_MODELS_TO_ONE_DATA
- self.comparator = "Model"
- assert isinstance(datasets, Dataset) or (
- isinstance(datasets, list) and len(datasets) == 1
- ), f"When comparing multiple models, must provide only one `datasets`. However got {datasets}"
- else:
- self.mode = CompareMode.ONE_MODEL_TO_MANY_DATA
- self.comparator = "Dataset"
-
- # validate models
- if all(map(lambda x: isinstance(x, str), models)):
- models = [Model(each, **auth_kwargs) for each in models]
- elif not all(map(lambda x: isinstance(x, Model), models)):
- raise ValueError(
- f"Expected all models are list of string or list of Model, got {[type(each) for each in models]}"
- )
- # validate datasets
- if not isinstance(datasets, list):
- datasets = [
- datasets,
- ]
- if all(map(lambda x: isinstance(x, str), datasets)):
- datasets = [Dataset(each, **auth_kwargs) for each in datasets]
- elif not all(map(lambda x: isinstance(x, Dataset), datasets)):
- raise ValueError(
- f"Expected datasets must be str, list of string or Dataset, list of Dataset, got {[type(each) for each in datasets]}"
- )
- # Validate models vs datasets together
- self._eval_handlers: List[_BaseEvalResultHandler] = []
- self.model_type = None
- logger.info("Initializing models...")
- for model in models:
- model.load_info()
- model_type = model.model_info.model_type_id
- if not self.model_type:
- self.model_type = model_type
- else:
- assert self.model_type == model_type, f"Can not compare when model types are different, {self.model_type} != {model_type}"
- m = make_handler_by_type(model_type)(model=model)
- logger.info(f"* {m.get_model_name(pretify=True)}")
- m.find_eval_id(datasets=datasets, attempt_evaluate=attempt_evaluate)
- self._eval_handlers.append(m)
-
- @property
- def eval_handlers(self):
- return self._eval_handlers
-
- def _loop_eval_handlers(self, func_name: str, **kwargs) -> Tuple[list, list]:
- """ Run methods of `eval_handlers[...].model`
-
- Args:
- func_name (str): method name, see `_BaseEvalResultHandler` child classes
- kwargs: keyword arguments of the method
-
- Return:
- tuple:
- - list of outputs
- - list of comparator names
-
- """
- outs = []
- comparators = []
- logger.info(f'Running `{func_name}`')
- for _, each in enumerate(self.eval_handlers):
- for ds_index, _ in enumerate(each.eval_data):
- func = eval(f'each.{func_name}')
- out = func(index=ds_index, **kwargs)
-
- if self.mode == CompareMode.MANY_MODELS_TO_ONE_DATA:
- name = each.get_model_name(pretify=True)
- else:
- name = each.get_dataset_name_by_index(ds_index, pretify=True)
- if out is None:
- logger.warning(f'{self.comparator}:{name} does not have valid data for `{func_name}`')
- continue
- comparators.append(name)
- outs.append(out)
-
- # remove app_id if models a
- if self.mode == CompareMode.MANY_MODELS_TO_ONE_DATA:
- apps = set([comp.split('/')[0] for comp in comparators])
- if len(apps) == 1:
- comparators = ['/'.join(comp.split('/')[1:]) for comp in comparators]
-
- if not outs:
- logger.warning(f'Model type {self.model_type} does not support `{func_name}`')
-
- return outs, comparators
-
- def detailed_summary(self,
- confidence_threshold: float = .5,
- iou_threshold: float = .5,
- area: str = "all",
- bypass_const=False) -> Union[Tuple[pd.DataFrame, pd.DataFrame], None]:
- """
- Retrieve and compute popular metrics of model.
-
- Args:
- confidence_threshold (float): confidence threshold, applicable for classification and detection. Default is 0.5
- iou_threshold (float): iou threshold, support in range(0.5, 1., step=0.1) applicable for detection
- area (float): size of area, support {all, small, medium}, applicable for detection
-
- Return:
- None or tuple of dataframe: df summary per concept and total concepts
-
- """
- df = []
- total = []
- # loop over all eval_handlers/dataset and call its method
- outs, comparators = self._loop_eval_handlers(
- 'detailed_summary',
- confidence_threshold=confidence_threshold,
- iou_threshold=iou_threshold,
- area=area,
- bypass_const=bypass_const)
- for indx, out in enumerate(outs):
- _df, _total = out
- _df[self.comparator] = [comparators[indx] for _ in range(len(_df))]
- _total['Concept'].replace(
- to_replace=['Total'], value=f'{self.comparator}:{comparators[indx]}', inplace=True)
- _total.rename({'Concept': 'Total Concept'}, axis=1, inplace=True)
- df.append(_df)
- total.append(_total)
-
- if df:
- df = pd.concat(df, axis=0)
- total = pd.concat(total, axis=0)
- return df, total
- else:
- return None
-
- def confusion_matrix(self, show=True, save_path: str = None,
- cm_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
- """Return dataframe of confusion matrix
- Args:
- show (bool, optional): Show the chart. Defaults to True.
- save_path (str): path to save rendered chart.
- cm_kwargs (dict): keyword args of `eval_handler[...].model.cm_kwargs` method.
- Returns:
- None or pd.Dataframe, If models don't have confusion matrix, return None
- """
- outs, comparators = self._loop_eval_handlers("confusion_matrix", **cm_kwargs)
- all_dfs = []
- for _, (df, anchor) in enumerate(zip(outs, comparators)):
- df[self.comparator] = [anchor for _ in range(len(df))]
- all_dfs.append(df)
-
- if all_dfs:
- all_dfs = pd.concat(all_dfs, axis=0)
- if save_path or show:
-
- def _facet_heatmap(data, **kws):
- data = data.dropna(axis=1)
- data = data.drop(self.comparator, axis=1)
- concepts = data.columns
- colnames = pd.MultiIndex.from_arrays([concepts], names=['Predicted'])
- data.columns = colnames
- ax = sns.heatmap(data, cmap='Blues', annot=True, annot_kws={"fontsize": 8}, **kws)
- ax.set_xticklabels(ax.get_xticklabels(), rotation=45, fontsize=6)
- ax.set_yticklabels(ax.get_yticklabels(), fontsize=6, rotation=0)
-
- temp = all_dfs.copy()
- temp.columns = ["_".join(pair) for pair in temp.columns]
- with sns.plotting_context(font_scale=5.5):
- g = sns.FacetGrid(
- temp,
- col=self.comparator,
- col_wrap=3,
- aspect=1,
- height=3,
- sharex=False,
- sharey=False,
- )
- cbar_ax = g.figure.add_axes([.92, .3, .02, .4])
- g = g.map_dataframe(
- _facet_heatmap, cbar_ax=cbar_ax, vmin=0, vmax=1, cbar=True, square=True)
- g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
- if show:
- plt.show()
- if save_path:
- g.savefig(save_path)
-
- return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
-
- @staticmethod
- def _set_default_kwargs(kwargs: dict, var_name: str, value):
- if var_name not in kwargs:
- kwargs.update({var_name: value})
- return kwargs
-
- @staticmethod
- def _setup_default_lineplot(df: pd.DataFrame, kwargs: dict):
- hue_order = df["concept"].unique().tolist()
- hue_order.remove(MACRO_AVG)
- hue_order.insert(0, MACRO_AVG)
- EvalResultCompare._set_default_kwargs(kwargs, "hue_order", hue_order)
-
- sizes = {}
- for each in hue_order:
- s = 1.5
- if each == MACRO_AVG:
- s = 4.
- sizes.update({each: s})
- EvalResultCompare._set_default_kwargs(kwargs, "sizes", sizes)
- EvalResultCompare._set_default_kwargs(kwargs, "size", "concept")
-
- EvalResultCompare._set_default_kwargs(kwargs, "errorbar", None)
- EvalResultCompare._set_default_kwargs(kwargs, "height", 5)
-
- return kwargs
-
- def roc_curve_plot(self,
- show=True,
- save_path: str = None,
- roc_curve_kwargs: dict = {},
- relplot_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
- """Return dataframe of ROC curve
- Args:
- show (bool, optional): Show the chart. Defaults to True.
- save_path (str): path to save rendered chart.
- pr_curve_kwargs (dict): keyword args of `eval_handler[...].model.roc_curve` method.
- relplot_kwargs (dict): keyword args of `sns.relplot` except {data,x,y,hue,kind,col}. where x="fpr", y="tpr", hue="concept"
- Returns:
- None or pd.Dataframe, If models don't have ROC curve, return None
- """
- sns.set_palette("Paired")
- outs, comparator = self._loop_eval_handlers("roc_curve", **roc_curve_kwargs)
- all_dfs = []
- for _, (df, anchor) in enumerate(zip(outs, comparator)):
- df[self.comparator] = [anchor for _ in range(len(df))]
- all_dfs.append(df)
-
- if all_dfs:
- all_dfs = pd.concat(all_dfs, axis=0)
- if save_path or show:
- relplot_kwargs = self._setup_default_lineplot(all_dfs, relplot_kwargs)
- g = sns.relplot(
- data=all_dfs,
- x="fpr",
- y="tpr",
- hue='concept',
- kind="line",
- col=self.comparator,
- **relplot_kwargs)
- g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
- if show:
- plt.show()
- if save_path:
- g.savefig(save_path)
-
- return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
-
- def pr_plot(self,
- show=True,
- save_path: str = None,
- pr_curve_kwargs: dict = {},
- relplot_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
- """Return dataframe of PR curve
- Args:
- show (bool, optional): Show the chart. Defaults to True.
- save_path (str): path to save rendered chart.
- pr_curve_kwargs (dict): keyword args of `eval_handler[...].model.pr_curve` method.
- relplot_kwargs (dict): keyword args of `sns.relplot` except {data,x,y,hue,kind,col} where x="recall", y="precision", hue="concept"
- Returns:
- None or pd.Dataframe, If models don't have PR curve, return None
- """
- sns.set_palette("Paired")
- outs, comparator = self._loop_eval_handlers("pr_curve", **pr_curve_kwargs)
- all_dfs = []
- for _, (df, anchor) in enumerate(zip(outs, comparator)):
- df[self.comparator] = [anchor for _ in range(len(df))]
- all_dfs.append(df)
-
- if all_dfs:
- all_dfs = pd.concat(all_dfs, axis=0)
- if save_path or show:
- relplot_kwargs = self._setup_default_lineplot(all_dfs, relplot_kwargs)
- g = sns.relplot(
- data=all_dfs,
- x="recall",
- y="precision",
- hue='concept',
- kind="line",
- col=self.comparator,
- **relplot_kwargs)
- g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
- if show:
- plt.show()
- if save_path:
- g.savefig(save_path)
-
- return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
-
- def all(
- self,
- output_folder: str,
- confidence_threshold: float = 0.5,
- iou_threshold: float = 0.5,
- overwrite: bool = False,
- metric_kwargs: dict = {},
- pr_plot_kwargs: dict = {},
- roc_plot_kwargs: dict = {},
- ):
- """Run all comparison methods one by one:
- - detailed_summary
- - pr_curve (if applicable)
- - pr_plot
- - confusion_matrix (if applicable)
- And save to output_folder
-
- Args:
- output_folder (str): path to output
- confidence_threshold (float): confidence threshold, applicable for classification and detection. Default is 0.5.
- iou_threshold (float): iou threshold, support in range(0.5, 1., step=0.1) applicable for detection.
- overwrite (bool): overwrite result of output_folder.
- metric_kwargs (dict): keyword args for `eval_handler[...].model.{method}`, except for {confidence_threshold, iou_threshold}.
- roc_plot_kwargs (dict): for relplot_kwargs of `roc_curve_plot` method.
- pr_plot_kwargs (dict): for relplot_kwargs of `pr_plot` method.
- """
- eval_type = get_eval_type(self.model_type)
- area = metric_kwargs.pop("area", "all")
- bypass_const = metric_kwargs.pop("bypass_const", False)
-
- fname = f"conf-{confidence_threshold}"
- if eval_type == EvalType.DETECTION:
- fname = f"{fname}_iou-{iou_threshold}_area-{area}"
-
- def join_root(*args):
- return os.path.join(output_folder, *args)
-
- output_folder = join_root(fname)
- if os.path.exists(output_folder) and not overwrite:
- raise RuntimeError(f"{output_folder} exists. If you want to overwrite, set `overwrite=True`")
-
- os.makedirs(output_folder, exist_ok=True)
-
- logger.info("Making summary tables...")
- dfs = self.detailed_summary(
- confidence_threshold=confidence_threshold,
- iou_threshold=iou_threshold,
- area=area,
- bypass_const=bypass_const)
- if dfs is not None:
- concept_df, total_df = dfs
- concept_df.to_csv(join_root("concepts_summary.csv"))
- total_df.to_csv(join_root("total_summary.csv"))
-
- curve_metric_kwargs = dict(
- confidence_threshold=confidence_threshold, iou_threshold=iou_threshold)
- curve_metric_kwargs.update(metric_kwargs)
-
- self.roc_curve_plot(
- show=False,
- save_path=join_root("roc.jpg"),
- roc_curve_kwargs=curve_metric_kwargs,
- relplot_kwargs=roc_plot_kwargs)
-
- self.pr_plot(
- show=False,
- save_path=join_root("pr.jpg"),
- pr_curve_kwargs=curve_metric_kwargs,
- relplot_kwargs=pr_plot_kwargs)
-
- self.confusion_matrix(
- show=False, save_path=join_root("confusion_matrix.jpg"), cm_kwargs=curve_metric_kwargs)
-
- logger.info(f"Done. Your outputs are saved at {output_folder}")
+ __all__ = ["EvalResultCompare"]
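The large deletion above strips an inline EvalResultCompare implementation (with its seaborn/matplotlib/pandas plotting code) out of what appears to be a package __init__ module and replaces it with a re-export from a sibling main module. Assuming that package is clarifai.utils.evaluation (the relative import from ..logging points at clarifai.utils), the public import path stays the same:

# implementation now lives in the .main submodule; the import below still works
from clarifai.utils.evaluation import EvalResultCompare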