clarifai 10.2.1__py3-none-any.whl → 10.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- clarifai/client/app.py +15 -7
- clarifai/client/auth/helper.py +12 -2
- clarifai/client/base.py +14 -4
- clarifai/client/dataset.py +15 -2
- clarifai/client/input.py +14 -1
- clarifai/client/model.py +198 -19
- clarifai/client/module.py +9 -1
- clarifai/client/search.py +10 -2
- clarifai/client/user.py +22 -14
- clarifai/client/workflow.py +9 -1
- clarifai/constants/input.py +1 -0
- clarifai/utils/evaluation/__init__.py +2 -426
- clarifai/utils/evaluation/main.py +426 -0
- clarifai/utils/evaluation/testset_annotation_parser.py +150 -0
- clarifai/versions.py +1 -1
- {clarifai-10.2.1.dist-info → clarifai-10.3.0.dist-info}/METADATA +12 -12
- {clarifai-10.2.1.dist-info → clarifai-10.3.0.dist-info}/RECORD +21 -22
- {clarifai-10.2.1.dist-info → clarifai-10.3.0.dist-info}/WHEEL +1 -1
- clarifai/client/runner.py +0 -234
- clarifai/runners/__init__.py +0 -0
- clarifai/runners/example.py +0 -40
- clarifai/runners/example_llama2.py +0 -81
- {clarifai-10.2.1.dist-info → clarifai-10.3.0.dist-info}/LICENSE +0 -0
- {clarifai-10.2.1.dist-info → clarifai-10.3.0.dist-info}/entry_points.txt +0 -0
- {clarifai-10.2.1.dist-info → clarifai-10.3.0.dist-info}/top_level.txt +0 -0
clarifai/utils/evaluation/main.py
ADDED
@@ -0,0 +1,426 @@
+import os
+from enum import Enum
+from typing import List, Tuple, Union
+
+from clarifai.client.dataset import Dataset
+from clarifai.client.model import Model
+
+from .helpers import (MACRO_AVG, EvalType, _BaseEvalResultHandler, get_eval_type,
+                      make_handler_by_type)
+
+try:
+  import seaborn as sns
+except ImportError:
+  raise ImportError("Can not import seaborn. Please run `pip install seaborn` to install it")
+
+try:
+  import matplotlib.pyplot as plt
+except ImportError:
+  raise ImportError("Can not import matplotlib. Please run `pip install matplotlib` to install it")
+
+try:
+  import pandas as pd
+except ImportError:
+  raise ImportError("Can not import pandas. Please run `pip install pandas` to install it")
+
+try:
+  from loguru import logger
+except ImportError:
+  from ..logging import get_logger
+  logger = get_logger(logger_level="INFO", name=__name__)
+
+__all__ = ['EvalResultCompare']
+
+
+class CompareMode(Enum):
+  MANY_MODELS_TO_ONE_DATA = 0
+  ONE_MODEL_TO_MANY_DATA = 1
+
+
+class EvalResultCompare:
+  """Compare evaluation result of models against datasets.
+  Note: The module will pick latest result on the datasets.
+  and models must be same model type
+
+  Args:
+  ---
+    models (Union[List[Model], List[str]]): List of Model or urls of models.
+    datasets (Union[Dataset, List[Dataset], str, List[str]]): A single or List of Url or Dataset
+    attempt_evaluate (bool): Evaluate when model is not evaluated with the datasets.
+    auth_kwargs (dict): Additional auth keyword arguments to be passed to the Dataset and Model if using url(s)
+  """
+
+  def __init__(self,
+               models: Union[List[Model], List[str]],
+               datasets: Union[Dataset, List[Dataset], str, List[str]],
+               attempt_evaluate: bool = False,
+               auth_kwargs: dict = {}):
+    assert isinstance(models, list), ValueError("Expected list")
+
+    if len(models) > 1:
+      self.mode = CompareMode.MANY_MODELS_TO_ONE_DATA
+      self.comparator = "Model"
+      assert isinstance(datasets, Dataset) or (
+          isinstance(datasets, list) and len(datasets) == 1
+      ), f"When comparing multiple models, must provide only one `datasets`. However got {datasets}"
+    else:
+      self.mode = CompareMode.ONE_MODEL_TO_MANY_DATA
+      self.comparator = "Dataset"
+
+    # validate models
+    if all(map(lambda x: isinstance(x, str), models)):
+      models = [Model(each, **auth_kwargs) for each in models]
+    elif not all(map(lambda x: isinstance(x, Model), models)):
+      raise ValueError(
+          f"Expected all models are list of string or list of Model, got {[type(each) for each in models]}"
+      )
+    # validate datasets
+    if not isinstance(datasets, list):
+      datasets = [
+          datasets,
+      ]
+    if all(map(lambda x: isinstance(x, str), datasets)):
+      datasets = [Dataset(each, **auth_kwargs) for each in datasets]
+    elif not all(map(lambda x: isinstance(x, Dataset), datasets)):
+      raise ValueError(
+          f"Expected datasets must be str, list of string or Dataset, list of Dataset, got {[type(each) for each in datasets]}"
+      )
+    # Validate models vs datasets together
+    self._eval_handlers: List[_BaseEvalResultHandler] = []
+    self.model_type = None
+    logger.info("Initializing models...")
+    for model in models:
+      model.load_info()
+      model_type = model.model_info.model_type_id
+      if not self.model_type:
+        self.model_type = model_type
+      else:
+        assert self.model_type == model_type, f"Can not compare when model types are different, {self.model_type} != {model_type}"
+      m = make_handler_by_type(model_type)(model=model)
+      logger.info(f"* {m.get_model_name(pretify=True)}")
+      m.find_eval_id(datasets=datasets, attempt_evaluate=attempt_evaluate)
+      self._eval_handlers.append(m)
+
+  @property
+  def eval_handlers(self):
+    return self._eval_handlers
+
+  def _loop_eval_handlers(self, func_name: str, **kwargs) -> Tuple[list, list]:
+    """ Run methods of `eval_handlers[...].model`
+
+    Args:
+      func_name (str): method name, see `_BaseEvalResultHandler` child classes
+      kwargs: keyword arguments of the method
+
+    Return:
+      tuple:
+        - list of outputs
+        - list of comparator names
+
+    """
+    outs = []
+    comparators = []
+    logger.info(f'Running `{func_name}`')
+    for _, each in enumerate(self.eval_handlers):
+      for ds_index, _ in enumerate(each.eval_data):
+        func = eval(f'each.{func_name}')
+        out = func(index=ds_index, **kwargs)
+
+        if self.mode == CompareMode.MANY_MODELS_TO_ONE_DATA:
+          name = each.get_model_name(pretify=True)
+        else:
+          name = each.get_dataset_name_by_index(ds_index, pretify=True)
+        if out is None:
+          logger.warning(f'{self.comparator}:{name} does not have valid data for `{func_name}`')
+          continue
+        comparators.append(name)
+        outs.append(out)
+
+    if self.mode == CompareMode.MANY_MODELS_TO_ONE_DATA:
+      apps = set([comp.split('/')[0] for comp in comparators])
+      if len(apps) == 1:
+        comparators = ['/'.join(comp.split('/')[1:]) for comp in comparators]
+
+    if not outs:
+      logger.warning(f'Model type {self.model_type} does not support `{func_name}`')
+
+    return outs, comparators
+
+  def detailed_summary(self,
+                       confidence_threshold: float = .5,
+                       iou_threshold: float = .5,
+                       area: str = "all",
+                       bypass_const=False) -> Union[Tuple[pd.DataFrame, pd.DataFrame], None]:
+    """
+    Retrieve and compute popular metrics of model.
+
+    Args:
+      confidence_threshold (float): confidence threshold, applicable for classification and detection. Default is 0.5
+      iou_threshold (float): iou threshold, support in range(0.5, 1., step=0.1) applicable for detection
+      area (float): size of area, support {all, small, medium}, applicable for detection
+
+    Return:
+      None or tuple of dataframe: df summary per concept and total concepts
+
+    """
+    df = []
+    total = []
+    # loop over all eval_handlers/dataset and call its method
+    outs, comparators = self._loop_eval_handlers(
+        'detailed_summary',
+        confidence_threshold=confidence_threshold,
+        iou_threshold=iou_threshold,
+        area=area,
+        bypass_const=bypass_const)
+    for indx, out in enumerate(outs):
+      _df, _total = out
+      _df[self.comparator] = [comparators[indx] for _ in range(len(_df))]
+      _total['Concept'].replace(
+          to_replace=['Total'], value=f'{self.comparator}:{comparators[indx]}', inplace=True)
+      _total.rename({'Concept': 'Total Concept'}, axis=1, inplace=True)
+      df.append(_df)
+      total.append(_total)
+
+    if df:
+      df = pd.concat(df, axis=0)
+      total = pd.concat(total, axis=0)
+      return df, total
+    else:
+      return None
+
+  def confusion_matrix(self, show=True, save_path: str = None,
+                       cm_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
+    """Return dataframe of confusion matrix
+    Args:
+      show (bool, optional): Show the chart. Defaults to True.
+      save_path (str): path to save rendered chart.
+      cm_kwargs (dict): keyword args of `eval_handler[...].model.cm_kwargs` method.
+    Returns:
+      None or pd.Dataframe, If models don't have confusion matrix, return None
+    """
+    outs, comparators = self._loop_eval_handlers("confusion_matrix", **cm_kwargs)
+    all_dfs = []
+    for _, (df, anchor) in enumerate(zip(outs, comparators)):
+      df[self.comparator] = [anchor for _ in range(len(df))]
+      all_dfs.append(df)
+
+    if all_dfs:
+      all_dfs = pd.concat(all_dfs, axis=0)
+      if save_path or show:
+
+        def _facet_heatmap(data, **kws):
+          data = data.dropna(axis=1)
+          data = data.drop(self.comparator, axis=1)
+          concepts = data.columns
+          colnames = pd.MultiIndex.from_arrays([concepts], names=['Predicted'])
+          data.columns = colnames
+          ax = sns.heatmap(data, cmap='Blues', annot=True, annot_kws={"fontsize": 8}, **kws)
+          ax.set_xticklabels(ax.get_xticklabels(), rotation=45, fontsize=6)
+          ax.set_yticklabels(ax.get_yticklabels(), fontsize=6, rotation=0)
+
+        temp = all_dfs.copy()
+        temp.columns = ["_".join(pair) for pair in temp.columns]
+        with sns.plotting_context(font_scale=5.5):
+          g = sns.FacetGrid(
+              temp,
+              col=self.comparator,
+              col_wrap=3,
+              aspect=1,
+              height=3,
+              sharex=False,
+              sharey=False,
+          )
+          cbar_ax = g.figure.add_axes([.92, .3, .02, .4])
+          g = g.map_dataframe(
+              _facet_heatmap, cbar_ax=cbar_ax, vmin=0, vmax=1, cbar=True, square=True)
+          g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
+          if show:
+            plt.show()
+          if save_path:
+            g.savefig(save_path)
+
+    return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
+
+  @staticmethod
+  def _set_default_kwargs(kwargs: dict, var_name: str, value):
+    if var_name not in kwargs:
+      kwargs.update({var_name: value})
+    return kwargs
+
+  @staticmethod
+  def _setup_default_lineplot(df: pd.DataFrame, kwargs: dict):
+    hue_order = df["concept"].unique().tolist()
+    hue_order.remove(MACRO_AVG)
+    hue_order.insert(0, MACRO_AVG)
+    EvalResultCompare._set_default_kwargs(kwargs, "hue_order", hue_order)
+
+    sizes = {}
+    for each in hue_order:
+      s = 1.5
+      if each == MACRO_AVG:
+        s = 4.
+      sizes.update({each: s})
+    EvalResultCompare._set_default_kwargs(kwargs, "sizes", sizes)
+    EvalResultCompare._set_default_kwargs(kwargs, "size", "concept")
+
+    EvalResultCompare._set_default_kwargs(kwargs, "errorbar", None)
+    EvalResultCompare._set_default_kwargs(kwargs, "height", 5)
+
+    return kwargs
+
+  def roc_curve_plot(self,
+                     show=True,
+                     save_path: str = None,
+                     roc_curve_kwargs: dict = {},
+                     relplot_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
+    """Return dataframe of ROC curve
+    Args:
+      show (bool, optional): Show the chart. Defaults to True.
+      save_path (str): path to save rendered chart.
+      pr_curve_kwargs (dict): keyword args of `eval_handler[...].model.roc_curve` method.
+      relplot_kwargs (dict): keyword args of `sns.relplot` except {data,x,y,hue,kind,col}. where x="fpr", y="tpr", hue="concept"
+    Returns:
+      None or pd.Dataframe, If models don't have ROC curve, return None
+    """
+    sns.set_palette("Paired")
+    outs, comparator = self._loop_eval_handlers("roc_curve", **roc_curve_kwargs)
+    all_dfs = []
+    for _, (df, anchor) in enumerate(zip(outs, comparator)):
+      df[self.comparator] = [anchor for _ in range(len(df))]
+      all_dfs.append(df)
+
+    if all_dfs:
+      all_dfs = pd.concat(all_dfs, axis=0)
+      if save_path or show:
+        relplot_kwargs = self._setup_default_lineplot(all_dfs, relplot_kwargs)
+        g = sns.relplot(
+            data=all_dfs,
+            x="fpr",
+            y="tpr",
+            hue='concept',
+            kind="line",
+            col=self.comparator,
+            **relplot_kwargs)
+        g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
+        if show:
+          plt.show()
+        if save_path:
+          g.savefig(save_path)
+
+    return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
+
+  def pr_plot(self,
+              show=True,
+              save_path: str = None,
+              pr_curve_kwargs: dict = {},
+              relplot_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
+    """Return dataframe of PR curve
+    Args:
+      show (bool, optional): Show the chart. Defaults to True.
+      save_path (str): path to save rendered chart.
+      pr_curve_kwargs (dict): keyword args of `eval_handler[...].model.pr_curve` method.
+      relplot_kwargs (dict): keyword args of `sns.relplot` except {data,x,y,hue,kind,col} where x="recall", y="precision", hue="concept"
+    Returns:
+      None or pd.Dataframe, If models don't have PR curve, return None
+    """
+    sns.set_palette("Paired")
+    outs, comparator = self._loop_eval_handlers("pr_curve", **pr_curve_kwargs)
+    all_dfs = []
+    for _, (df, anchor) in enumerate(zip(outs, comparator)):
+      df[self.comparator] = [anchor for _ in range(len(df))]
+      all_dfs.append(df)
+
+    if all_dfs:
+      all_dfs = pd.concat(all_dfs, axis=0)
+      if save_path or show:
+        relplot_kwargs = self._setup_default_lineplot(all_dfs, relplot_kwargs)
+        g = sns.relplot(
+            data=all_dfs,
+            x="recall",
+            y="precision",
+            hue='concept',
+            kind="line",
+            col=self.comparator,
+            **relplot_kwargs)
+        g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
+        if show:
+          plt.show()
+        if save_path:
+          g.savefig(save_path)
+
+    return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
+
+  def all(
+      self,
+      output_folder: str,
+      confidence_threshold: float = 0.5,
+      iou_threshold: float = 0.5,
+      overwrite: bool = False,
+      metric_kwargs: dict = {},
+      pr_plot_kwargs: dict = {},
+      roc_plot_kwargs: dict = {},
+  ):
+    """Run all comparison methods one by one:
+    - detailed_summary
+    - pr_curve (if applicable)
+    - pr_plot
+    - confusion_matrix (if applicable)
+    And save to output_folder
+
+    Args:
+      output_folder (str): path to output
+      confidence_threshold (float): confidence threshold, applicable for classification and detection. Default is 0.5.
+      iou_threshold (float): iou threshold, support in range(0.5, 1., step=0.1) applicable for detection.
+      overwrite (bool): overwrite result of output_folder.
+      metric_kwargs (dict): keyword args for `eval_handler[...].model.{method}`, except for {confidence_threshold, iou_threshold}.
+      roc_plot_kwargs (dict): for relplot_kwargs of `roc_curve_plot` method.
+      pr_plot_kwargs (dict): for relplot_kwargs of `pr_plot` method.
+    """
+    eval_type = get_eval_type(self.model_type)
+    area = metric_kwargs.pop("area", "all")
+    bypass_const = metric_kwargs.pop("bypass_const", False)
+
+    fname = f"conf-{confidence_threshold}"
+    if eval_type == EvalType.DETECTION:
+      fname = f"{fname}_iou-{iou_threshold}_area-{area}"
+
+    def join_root(*args):
+      return os.path.join(output_folder, *args)
+
+    output_folder = join_root(fname)
+    if os.path.exists(output_folder) and not overwrite:
+      raise RuntimeError(f"{output_folder} exists. If you want to overwrite, set `overwrite=True`")
+
+    os.makedirs(output_folder, exist_ok=True)
+
+    logger.info("Making summary tables...")
+    dfs = self.detailed_summary(
+        confidence_threshold=confidence_threshold,
+        iou_threshold=iou_threshold,
+        area=area,
+        bypass_const=bypass_const)
+    if dfs is not None:
+      concept_df, total_df = dfs
+      concept_df.to_csv(join_root("concepts_summary.csv"))
+      total_df.to_csv(join_root("total_summary.csv"))
+
+    curve_metric_kwargs = dict(
+        confidence_threshold=confidence_threshold, iou_threshold=iou_threshold)
+    curve_metric_kwargs.update(metric_kwargs)
+
+    self.roc_curve_plot(
+        show=False,
+        save_path=join_root("roc.jpg"),
+        roc_curve_kwargs=curve_metric_kwargs,
+        relplot_kwargs=roc_plot_kwargs)
+
+    self.pr_plot(
+        show=False,
+        save_path=join_root("pr.jpg"),
+        pr_curve_kwargs=curve_metric_kwargs,
+        relplot_kwargs=pr_plot_kwargs)
+
+    self.confusion_matrix(
+        show=False, save_path=join_root("confusion_matrix.jpg"), cm_kwargs=curve_metric_kwargs)
+
+    logger.info(f"Done. Your outputs are saved at {output_folder}")
clarifai/utils/evaluation/testset_annotation_parser.py
ADDED
@@ -0,0 +1,150 @@
+from typing import List, Tuple
+import numpy as np
+from clarifai_grpc.grpc.api import resources_pb2 as respb2
+
+
+def parse_eval_annotation_classifier(
+    eval_metrics: respb2.EvalMetrics) -> Tuple[np.array, np.array, list, List[respb2.Input]]:
+  test_set = eval_metrics.test_set
+  # get concept ids
+  concept_ids = [each.id for each in test_set[0].predicted_concepts]
+  concept_ids.sort()
+  # get test set
+  y_preds = []
+  y = []
+  inputs = []
+  for data in test_set:
+
+    def _to_array(_data):
+      cps = [0] * len(concept_ids)
+      for each in _data:
+        cps[concept_ids.index(each.id)] = each.value
+      return np.asarray(cps)
+
+    y_preds.append(_to_array(data.predicted_concepts))
+    y.append(_to_array(data.ground_truth_concepts))
+    inputs.append(data.input)
+
+  return np.asarray(y), np.asarray(y_preds), concept_ids, inputs
+
+
+def parse_eval_annotation_detector(eval_metrics: respb2.EvalMetrics,
+                                   normalized_box: bool = False,
+                                   box_style: str = "xyxy"
+                                  ) -> Tuple[np.array, np.array, list, List[respb2.Input]]:
+  BOX_STYLES = ["xyxy", "xywh"]
+  assert box_style in BOX_STYLES, ValueError(f"Expected box_style in {BOX_STYLES}")
+
+  concept_ids = list(set([each.concept.id for each in eval_metrics.metrics_by_class]))
+  concept_ids.sort()
+
+  def _get_box_annot(field, img_height, img_width):
+    xyxy_concept_score = []
+    for each in field:
+      box = each.region_info.bounding_box
+      x1 = box.left_col * img_width
+      y1 = box.top_row * img_height
+      x2 = box.right_col * img_width
+      y2 = box.bottom_row * img_height
+      score = each.data.concepts[0].value
+      concept = each.data.concepts[0].id
+      concept_index = concept_ids.index(concept)
+      if box_style == "xyxy":
+        xyxy_concept_score.append([x1, y1, x2, y2, concept_index, score])
+      else:
+        w = abs(x1 - x2)
+        h = abs(y1 - y2)
+        xyxy_concept_score.append([x1, y1, w, h, concept_index, score])
+
+    return np.asarray(xyxy_concept_score)
+
+  inputs = []
+  pred_xyxy_concept_score = []
+  gt_xyxy_concept_score = []
+  for input_data in eval_metrics.test_set:
+    _input = input_data.input
+    img_height = _input.data.image.image_info.height if not normalized_box else 1.
+    img_width = _input.data.image.image_info.height if not normalized_box else 1.
+    _pred_xyxy_concept_score = _get_box_annot(
+        input_data.predicted_annotation.data.regions, img_height=img_height, img_width=img_width)
+    _gt_xyxy_concept_score = _get_box_annot(
+        input_data.ground_truth_annotation.data.regions,
+        img_height=img_height,
+        img_width=img_width)
+
+    pred_xyxy_concept_score.append(_pred_xyxy_concept_score)
+    gt_xyxy_concept_score.append(_gt_xyxy_concept_score)
+    inputs.append(_input)
+
+  return np.asarray(gt_xyxy_concept_score), np.asarray(
+      pred_xyxy_concept_score), concept_ids, inputs
+
+
+def parse_eval_annotation_detector_coco(eval_metrics: respb2.EvalMetrics,
+                                       ) -> Tuple[np.array, np.array, list, List[respb2.Input]]:
+
+  gts, preds, concept_ids, inputs = parse_eval_annotation_detector(
+      eval_metrics=eval_metrics, normalized_box=False, box_style="xywh")
+
+  def _make_box_annot(data, input_data, accum_id, is_pred=True):
+    img_id = input_data.id
+    img_url = input_data.data.image.url
+    img_height = input_data.data.image.image_info.height
+    img_width = input_data.data.image.image_info.height
+    image = {
+        "id": img_id,
+        "file_name": img_url,
+        "width": img_width,
+        "height": img_height,
+    }
+    annotations = []
+    for i, d in enumerate(data.tolist()):
+      area = d[2] * d[3]
+      box = {
+          "iscrowd": 0,
+          "ignore": 0,
+          "image_id": img_id,
+          "bbox": d[:4],
+          "area": area,
+          "segmentation": [],
+          "category_id": d[4],
+          "id": accum_id + i,
+      }
+      if is_pred:
+        box["score"] = d[5]
+      annotations.append(box)
+
+    return image, annotations
+
+  categories = [{
+      "supercategory": "none",
+      "id": i,
+      "name": label
+  } for i, label in enumerate(concept_ids)]
+
+  accum_pred_ids = 0
+  accum_gt_ids = 0
+  pred_images = []
+  pred_boxes = []
+  gt_images = []
+  gt_boxes = []
+  for ith, input_proto in enumerate(inputs):
+    pred = preds[ith]
+    gt = gts[ith]
+    pred_img, pred_box = _make_box_annot(
+        pred, accum_id=accum_pred_ids, input_data=input_proto, is_pred=True)
+    gt_img, gt_box = _make_box_annot(
+        gt, accum_id=accum_gt_ids, input_data=input_proto, is_pred=False)
+
+    accum_pred_ids += len(pred)
+    pred_images.append(pred_img)
+    pred_boxes += pred_box
+
+    accum_gt_ids += len(gt)
+    gt_images.append(gt_img)
+    gt_boxes += gt_box
+
+  pred_annots = {"images": pred_images, "annotations": pred_boxes, "categories": categories}
+  gt_annots = {"images": gt_images, "annotations": gt_boxes, "categories": categories}
+
+  return gt_annots, pred_annots
clarifai/versions.py
CHANGED
{clarifai-10.2.1.dist-info → clarifai-10.3.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: clarifai
-Version: 10.2.1
+Version: 10.3.0
 Summary: Clarifai Python SDK
 Home-page: https://github.com/Clarifai/clarifai-python
 Author: Clarifai
@@ -20,18 +20,18 @@ Classifier: Operating System :: OS Independent
 Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: clarifai-grpc
-Requires-Dist: numpy
-Requires-Dist: tqdm
-Requires-Dist: tritonclient
-Requires-Dist: rich
-Requires-Dist: PyYAML
-Requires-Dist: schema
-Requires-Dist: Pillow
-Requires-Dist: inquirerpy
-Requires-Dist: tabulate
+Requires-Dist: clarifai-grpc ~=10.2.3
+Requires-Dist: numpy >=1.22.0
+Requires-Dist: tqdm >=4.65.0
+Requires-Dist: tritonclient >=2.34.0
+Requires-Dist: rich >=13.4.2
+Requires-Dist: PyYAML >=6.0.1
+Requires-Dist: schema >=0.7.5
+Requires-Dist: Pillow >=9.5.0
+Requires-Dist: inquirerpy ==0.3.4
+Requires-Dist: tabulate >=0.9.0
 Provides-Extra: all
-Requires-Dist: pycocotools
+Requires-Dist: pycocotools ==2.0.6 ; extra == 'all'
 
 <h1 align="center">
 <a href="https://www.clarifai.com/"><img alt="Clarifai" title="Clarifai" src="https://upload.wikimedia.org/wikipedia/commons/b/bc/Clarifai_Logo_FC_Web.png"></a>