clarifai 10.2.1__py3-none-any.whl → 10.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,426 @@
+ import os
+ from enum import Enum
+ from typing import List, Tuple, Union
+
+ from clarifai.client.dataset import Dataset
+ from clarifai.client.model import Model
+
+ from .helpers import (MACRO_AVG, EvalType, _BaseEvalResultHandler, get_eval_type,
+                       make_handler_by_type)
+
+ try:
+   import seaborn as sns
+ except ImportError:
+   raise ImportError("Cannot import seaborn. Please run `pip install seaborn` to install it")
+
+ try:
+   import matplotlib.pyplot as plt
+ except ImportError:
+   raise ImportError("Cannot import matplotlib. Please run `pip install matplotlib` to install it")
+
+ try:
+   import pandas as pd
+ except ImportError:
+   raise ImportError("Cannot import pandas. Please run `pip install pandas` to install it")
+
+ try:
+   from loguru import logger
+ except ImportError:
+   from ..logging import get_logger
+   logger = get_logger(logger_level="INFO", name=__name__)
+
+ __all__ = ['EvalResultCompare']
+
+
+ class CompareMode(Enum):
+   MANY_MODELS_TO_ONE_DATA = 0
+   ONE_MODEL_TO_MANY_DATA = 1
+
+
+ class EvalResultCompare:
+   """Compare evaluation results of models against datasets.
+   Note: The module picks the latest evaluation result on each dataset,
+   and all models must be of the same model type.
+
+   Args:
+   ---
+     models (Union[List[Model], List[str]]): List of Model objects or model URLs.
+     datasets (Union[Dataset, List[Dataset], str, List[str]]): A single Dataset/URL or a list of them.
+     attempt_evaluate (bool): Evaluate when a model has not yet been evaluated with the datasets.
+     auth_kwargs (dict): Additional auth keyword arguments passed to Dataset and Model when URL(s) are used.
+   """
+
+   def __init__(self,
+                models: Union[List[Model], List[str]],
+                datasets: Union[Dataset, List[Dataset], str, List[str]],
+                attempt_evaluate: bool = False,
+                auth_kwargs: dict = {}):
+     assert isinstance(models, list), ValueError("Expected list")
+
+     if len(models) > 1:
+       self.mode = CompareMode.MANY_MODELS_TO_ONE_DATA
+       self.comparator = "Model"
+       assert isinstance(datasets, Dataset) or (
+           isinstance(datasets, list) and len(datasets) == 1
+       ), f"When comparing multiple models, must provide only one `datasets`. However got {datasets}"
+     else:
+       self.mode = CompareMode.ONE_MODEL_TO_MANY_DATA
+       self.comparator = "Dataset"
+
+     # validate models
+     if all(map(lambda x: isinstance(x, str), models)):
+       models = [Model(each, **auth_kwargs) for each in models]
+     elif not all(map(lambda x: isinstance(x, Model), models)):
+       raise ValueError(
+           f"Expected all models to be str or Model, got {[type(each) for each in models]}"
+       )
+     # validate datasets
+     if not isinstance(datasets, list):
+       datasets = [
+           datasets,
+       ]
+     if all(map(lambda x: isinstance(x, str), datasets)):
+       datasets = [Dataset(each, **auth_kwargs) for each in datasets]
+     elif not all(map(lambda x: isinstance(x, Dataset), datasets)):
+       raise ValueError(
+           f"Expected datasets to be str, Dataset, or a list of either, got {[type(each) for each in datasets]}"
+       )
+     # Validate models vs datasets together
+     self._eval_handlers: List[_BaseEvalResultHandler] = []
+     self.model_type = None
+     logger.info("Initializing models...")
+     for model in models:
+       model.load_info()
+       model_type = model.model_info.model_type_id
+       if not self.model_type:
+         self.model_type = model_type
+       else:
+         assert self.model_type == model_type, f"Cannot compare when model types are different, {self.model_type} != {model_type}"
+       m = make_handler_by_type(model_type)(model=model)
+       logger.info(f"* {m.get_model_name(pretify=True)}")
+       m.find_eval_id(datasets=datasets, attempt_evaluate=attempt_evaluate)
+       self._eval_handlers.append(m)
+
+   @property
+   def eval_handlers(self):
+     return self._eval_handlers
+
+   def _loop_eval_handlers(self, func_name: str, **kwargs) -> Tuple[list, list]:
+     """Run methods of `eval_handlers[...].model`
+
+     Args:
+       func_name (str): method name, see `_BaseEvalResultHandler` child classes
+       kwargs: keyword arguments of the method
+
+     Return:
+       tuple:
+         - list of outputs
+         - list of comparator names
+
+     """
+     outs = []
+     comparators = []
+     logger.info(f'Running `{func_name}`')
+     for _, each in enumerate(self.eval_handlers):
+       for ds_index, _ in enumerate(each.eval_data):
+         func = eval(f'each.{func_name}')
+         out = func(index=ds_index, **kwargs)
+
+         if self.mode == CompareMode.MANY_MODELS_TO_ONE_DATA:
+           name = each.get_model_name(pretify=True)
+         else:
+           name = each.get_dataset_name_by_index(ds_index, pretify=True)
+         if out is None:
+           logger.warning(f'{self.comparator}:{name} does not have valid data for `{func_name}`')
+           continue
+         comparators.append(name)
+         outs.append(out)
+
+     if self.mode == CompareMode.MANY_MODELS_TO_ONE_DATA:
+       apps = set([comp.split('/')[0] for comp in comparators])
+       if len(apps) == 1:
+         comparators = ['/'.join(comp.split('/')[1:]) for comp in comparators]
+
+     if not outs:
+       logger.warning(f'Model type {self.model_type} does not support `{func_name}`')
+
+     return outs, comparators
+
+   def detailed_summary(self,
+                        confidence_threshold: float = .5,
+                        iou_threshold: float = .5,
+                        area: str = "all",
+                        bypass_const=False) -> Union[Tuple[pd.DataFrame, pd.DataFrame], None]:
+     """
+     Retrieve and compute popular metrics of the models.
+
+     Args:
+       confidence_threshold (float): confidence threshold, applicable for classification and detection. Default is 0.5
+       iou_threshold (float): IoU threshold, supported in range(0.5, 1.0, step 0.1), applicable for detection
+       area (str): size of area, one of {all, small, medium}, applicable for detection
+
+     Return:
+       None or tuple of DataFrames: per-concept summary and total over all concepts
+
+     """
+     df = []
+     total = []
+     # loop over all eval_handlers/datasets and call the method
+     outs, comparators = self._loop_eval_handlers(
+         'detailed_summary',
+         confidence_threshold=confidence_threshold,
+         iou_threshold=iou_threshold,
+         area=area,
+         bypass_const=bypass_const)
+     for indx, out in enumerate(outs):
+       _df, _total = out
+       _df[self.comparator] = [comparators[indx] for _ in range(len(_df))]
+       _total['Concept'].replace(
+           to_replace=['Total'], value=f'{self.comparator}:{comparators[indx]}', inplace=True)
+       _total.rename({'Concept': 'Total Concept'}, axis=1, inplace=True)
+       df.append(_df)
+       total.append(_total)
+
+     if df:
+       df = pd.concat(df, axis=0)
+       total = pd.concat(total, axis=0)
+       return df, total
+     else:
+       return None
+
+   def confusion_matrix(self, show=True, save_path: str = None,
+                        cm_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
+     """Return a dataframe of the confusion matrix.
+     Args:
+       show (bool, optional): Show the chart. Defaults to True.
+       save_path (str): path to save the rendered chart.
+       cm_kwargs (dict): keyword args of the `eval_handler[...].model.confusion_matrix` method.
+     Returns:
+       None or pd.DataFrame; if the models don't have a confusion matrix, returns None
+     """
+     outs, comparators = self._loop_eval_handlers("confusion_matrix", **cm_kwargs)
+     all_dfs = []
+     for _, (df, anchor) in enumerate(zip(outs, comparators)):
+       df[self.comparator] = [anchor for _ in range(len(df))]
+       all_dfs.append(df)
+
+     if all_dfs:
+       all_dfs = pd.concat(all_dfs, axis=0)
+       if save_path or show:
+
+         def _facet_heatmap(data, **kws):
+           data = data.dropna(axis=1)
+           data = data.drop(self.comparator, axis=1)
+           concepts = data.columns
+           colnames = pd.MultiIndex.from_arrays([concepts], names=['Predicted'])
+           data.columns = colnames
+           ax = sns.heatmap(data, cmap='Blues', annot=True, annot_kws={"fontsize": 8}, **kws)
+           ax.set_xticklabels(ax.get_xticklabels(), rotation=45, fontsize=6)
+           ax.set_yticklabels(ax.get_yticklabels(), fontsize=6, rotation=0)
+
+         temp = all_dfs.copy()
+         temp.columns = ["_".join(pair) for pair in temp.columns]
+         with sns.plotting_context(font_scale=5.5):
+           g = sns.FacetGrid(
+               temp,
+               col=self.comparator,
+               col_wrap=3,
+               aspect=1,
+               height=3,
+               sharex=False,
+               sharey=False,
+           )
+           cbar_ax = g.figure.add_axes([.92, .3, .02, .4])
+           g = g.map_dataframe(
+               _facet_heatmap, cbar_ax=cbar_ax, vmin=0, vmax=1, cbar=True, square=True)
+           g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
+           if show:
+             plt.show()
+           if save_path:
+             g.savefig(save_path)
+
+     return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
+
+   @staticmethod
+   def _set_default_kwargs(kwargs: dict, var_name: str, value):
+     if var_name not in kwargs:
+       kwargs.update({var_name: value})
+     return kwargs
+
+   @staticmethod
+   def _setup_default_lineplot(df: pd.DataFrame, kwargs: dict):
+     hue_order = df["concept"].unique().tolist()
+     hue_order.remove(MACRO_AVG)
+     hue_order.insert(0, MACRO_AVG)
+     EvalResultCompare._set_default_kwargs(kwargs, "hue_order", hue_order)
+
+     sizes = {}
+     for each in hue_order:
+       s = 1.5
+       if each == MACRO_AVG:
+         s = 4.
+       sizes.update({each: s})
+     EvalResultCompare._set_default_kwargs(kwargs, "sizes", sizes)
+     EvalResultCompare._set_default_kwargs(kwargs, "size", "concept")
+
+     EvalResultCompare._set_default_kwargs(kwargs, "errorbar", None)
+     EvalResultCompare._set_default_kwargs(kwargs, "height", 5)
+
+     return kwargs
+
+   def roc_curve_plot(self,
+                      show=True,
+                      save_path: str = None,
+                      roc_curve_kwargs: dict = {},
+                      relplot_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
+     """Return a dataframe of the ROC curve.
+     Args:
+       show (bool, optional): Show the chart. Defaults to True.
+       save_path (str): path to save the rendered chart.
+       roc_curve_kwargs (dict): keyword args of the `eval_handler[...].model.roc_curve` method.
+       relplot_kwargs (dict): keyword args of `sns.relplot` except {data,x,y,hue,kind,col}, where x="fpr", y="tpr", hue="concept".
+     Returns:
+       None or pd.DataFrame; if the models don't have a ROC curve, returns None
+     """
+     sns.set_palette("Paired")
+     outs, comparator = self._loop_eval_handlers("roc_curve", **roc_curve_kwargs)
+     all_dfs = []
+     for _, (df, anchor) in enumerate(zip(outs, comparator)):
+       df[self.comparator] = [anchor for _ in range(len(df))]
+       all_dfs.append(df)
+
+     if all_dfs:
+       all_dfs = pd.concat(all_dfs, axis=0)
+       if save_path or show:
+         relplot_kwargs = self._setup_default_lineplot(all_dfs, relplot_kwargs)
+         g = sns.relplot(
+             data=all_dfs,
+             x="fpr",
+             y="tpr",
+             hue='concept',
+             kind="line",
+             col=self.comparator,
+             **relplot_kwargs)
+         g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
+         if show:
+           plt.show()
+         if save_path:
+           g.savefig(save_path)
+
+     return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
+
+   def pr_plot(self,
+               show=True,
+               save_path: str = None,
+               pr_curve_kwargs: dict = {},
+               relplot_kwargs: dict = {}) -> Union[pd.DataFrame, None]:
+     """Return a dataframe of the PR curve.
+     Args:
+       show (bool, optional): Show the chart. Defaults to True.
+       save_path (str): path to save the rendered chart.
+       pr_curve_kwargs (dict): keyword args of the `eval_handler[...].model.pr_curve` method.
+       relplot_kwargs (dict): keyword args of `sns.relplot` except {data,x,y,hue,kind,col}, where x="recall", y="precision", hue="concept".
+     Returns:
+       None or pd.DataFrame; if the models don't have a PR curve, returns None
+     """
+     sns.set_palette("Paired")
+     outs, comparator = self._loop_eval_handlers("pr_curve", **pr_curve_kwargs)
+     all_dfs = []
+     for _, (df, anchor) in enumerate(zip(outs, comparator)):
+       df[self.comparator] = [anchor for _ in range(len(df))]
+       all_dfs.append(df)
+
+     if all_dfs:
+       all_dfs = pd.concat(all_dfs, axis=0)
+       if save_path or show:
+         relplot_kwargs = self._setup_default_lineplot(all_dfs, relplot_kwargs)
+         g = sns.relplot(
+             data=all_dfs,
+             x="recall",
+             y="precision",
+             hue='concept',
+             kind="line",
+             col=self.comparator,
+             **relplot_kwargs)
+         g.set_titles(col_template=str(self.comparator) + ':{col_name}', fontsize=5)
+         if show:
+           plt.show()
+         if save_path:
+           g.savefig(save_path)
+
+     return all_dfs if isinstance(all_dfs, pd.DataFrame) else None
+
+   def all(
+       self,
+       output_folder: str,
+       confidence_threshold: float = 0.5,
+       iou_threshold: float = 0.5,
+       overwrite: bool = False,
+       metric_kwargs: dict = {},
+       pr_plot_kwargs: dict = {},
+       roc_plot_kwargs: dict = {},
+   ):
+     """Run all comparison methods one by one:
+     - detailed_summary
+     - roc_curve_plot (if applicable)
+     - pr_plot
+     - confusion_matrix (if applicable)
+     and save the results to output_folder.
+
+     Args:
+       output_folder (str): path to the output folder
+       confidence_threshold (float): confidence threshold, applicable for classification and detection. Default is 0.5.
+       iou_threshold (float): IoU threshold, supported in range(0.5, 1.0, step 0.1), applicable for detection.
+       overwrite (bool): overwrite existing results in output_folder.
+       metric_kwargs (dict): keyword args for `eval_handler[...].model.{method}`, except for {confidence_threshold, iou_threshold}.
+       roc_plot_kwargs (dict): relplot_kwargs of the `roc_curve_plot` method.
+       pr_plot_kwargs (dict): relplot_kwargs of the `pr_plot` method.
+     """
+     eval_type = get_eval_type(self.model_type)
+     area = metric_kwargs.pop("area", "all")
+     bypass_const = metric_kwargs.pop("bypass_const", False)
+
+     fname = f"conf-{confidence_threshold}"
+     if eval_type == EvalType.DETECTION:
+       fname = f"{fname}_iou-{iou_threshold}_area-{area}"
+
+     def join_root(*args):
+       return os.path.join(output_folder, *args)
+
+     output_folder = join_root(fname)
+     if os.path.exists(output_folder) and not overwrite:
+       raise RuntimeError(f"{output_folder} exists. If you want to overwrite, set `overwrite=True`")
+
+     os.makedirs(output_folder, exist_ok=True)
+
+     logger.info("Making summary tables...")
+     dfs = self.detailed_summary(
+         confidence_threshold=confidence_threshold,
+         iou_threshold=iou_threshold,
+         area=area,
+         bypass_const=bypass_const)
+     if dfs is not None:
+       concept_df, total_df = dfs
+       concept_df.to_csv(join_root("concepts_summary.csv"))
+       total_df.to_csv(join_root("total_summary.csv"))
+
+     curve_metric_kwargs = dict(
+         confidence_threshold=confidence_threshold, iou_threshold=iou_threshold)
+     curve_metric_kwargs.update(metric_kwargs)
+
+     self.roc_curve_plot(
+         show=False,
+         save_path=join_root("roc.jpg"),
+         roc_curve_kwargs=curve_metric_kwargs,
+         relplot_kwargs=roc_plot_kwargs)
+
+     self.pr_plot(
+         show=False,
+         save_path=join_root("pr.jpg"),
+         pr_curve_kwargs=curve_metric_kwargs,
+         relplot_kwargs=pr_plot_kwargs)
+
+     self.confusion_matrix(
+         show=False, save_path=join_root("confusion_matrix.jpg"), cm_kwargs=curve_metric_kwargs)
+
+     logger.info(f"Done. Your outputs are saved at {output_folder}")
@@ -0,0 +1,150 @@
+ from typing import List, Tuple
+ import numpy as np
+ from clarifai_grpc.grpc.api import resources_pb2 as respb2
+
+
+ def parse_eval_annotation_classifier(
+     eval_metrics: respb2.EvalMetrics) -> Tuple[np.array, np.array, list, List[respb2.Input]]:
+   test_set = eval_metrics.test_set
+   # get concept ids
+   concept_ids = [each.id for each in test_set[0].predicted_concepts]
+   concept_ids.sort()
+   # get test set
+   y_preds = []
+   y = []
+   inputs = []
+   for data in test_set:
+
+     def _to_array(_data):
+       cps = [0] * len(concept_ids)
+       for each in _data:
+         cps[concept_ids.index(each.id)] = each.value
+       return np.asarray(cps)
+
+     y_preds.append(_to_array(data.predicted_concepts))
+     y.append(_to_array(data.ground_truth_concepts))
+     inputs.append(data.input)
+
+   return np.asarray(y), np.asarray(y_preds), concept_ids, inputs
+
+
+ def parse_eval_annotation_detector(eval_metrics: respb2.EvalMetrics,
+                                    normalized_box: bool = False,
+                                    box_style: str = "xyxy"
+                                   ) -> Tuple[np.array, np.array, list, List[respb2.Input]]:
+   BOX_STYLES = ["xyxy", "xywh"]
+   assert box_style in BOX_STYLES, ValueError(f"Expected box_style in {BOX_STYLES}")
+
+   concept_ids = list(set([each.concept.id for each in eval_metrics.metrics_by_class]))
+   concept_ids.sort()
+
+   def _get_box_annot(field, img_height, img_width):
+     xyxy_concept_score = []
+     for each in field:
+       box = each.region_info.bounding_box
+       x1 = box.left_col * img_width
+       y1 = box.top_row * img_height
+       x2 = box.right_col * img_width
+       y2 = box.bottom_row * img_height
+       score = each.data.concepts[0].value
+       concept = each.data.concepts[0].id
+       concept_index = concept_ids.index(concept)
+       if box_style == "xyxy":
+         xyxy_concept_score.append([x1, y1, x2, y2, concept_index, score])
+       else:
+         w = abs(x1 - x2)
+         h = abs(y1 - y2)
+         xyxy_concept_score.append([x1, y1, w, h, concept_index, score])
+
+     return np.asarray(xyxy_concept_score)
+
+   inputs = []
+   pred_xyxy_concept_score = []
+   gt_xyxy_concept_score = []
+   for input_data in eval_metrics.test_set:
+     _input = input_data.input
+     img_height = _input.data.image.image_info.height if not normalized_box else 1.
+     img_width = _input.data.image.image_info.width if not normalized_box else 1.
+     _pred_xyxy_concept_score = _get_box_annot(
+         input_data.predicted_annotation.data.regions, img_height=img_height, img_width=img_width)
+     _gt_xyxy_concept_score = _get_box_annot(
+         input_data.ground_truth_annotation.data.regions,
+         img_height=img_height,
+         img_width=img_width)
+
+     pred_xyxy_concept_score.append(_pred_xyxy_concept_score)
+     gt_xyxy_concept_score.append(_gt_xyxy_concept_score)
+     inputs.append(_input)
+
+   return np.asarray(gt_xyxy_concept_score), np.asarray(
+       pred_xyxy_concept_score), concept_ids, inputs
+
+
+ def parse_eval_annotation_detector_coco(eval_metrics: respb2.EvalMetrics,
+                                        ) -> Tuple[dict, dict]:
+
+   gts, preds, concept_ids, inputs = parse_eval_annotation_detector(
+       eval_metrics=eval_metrics, normalized_box=False, box_style="xywh")
+
+   def _make_box_annot(data, input_data, accum_id, is_pred=True):
+     img_id = input_data.id
+     img_url = input_data.data.image.url
+     img_height = input_data.data.image.image_info.height
+     img_width = input_data.data.image.image_info.width
+     image = {
+         "id": img_id,
+         "file_name": img_url,
+         "width": img_width,
+         "height": img_height,
+     }
+     annotations = []
+     for i, d in enumerate(data.tolist()):
+       area = d[2] * d[3]
+       box = {
+           "iscrowd": 0,
+           "ignore": 0,
+           "image_id": img_id,
+           "bbox": d[:4],
+           "area": area,
+           "segmentation": [],
+           "category_id": d[4],
+           "id": accum_id + i,
+       }
+       if is_pred:
+         box["score"] = d[5]
+       annotations.append(box)
+
+     return image, annotations
+
+   categories = [{
+       "supercategory": "none",
+       "id": i,
+       "name": label
+   } for i, label in enumerate(concept_ids)]
+
+   accum_pred_ids = 0
+   accum_gt_ids = 0
+   pred_images = []
+   pred_boxes = []
+   gt_images = []
+   gt_boxes = []
+   for ith, input_proto in enumerate(inputs):
+     pred = preds[ith]
+     gt = gts[ith]
+     pred_img, pred_box = _make_box_annot(
+         pred, accum_id=accum_pred_ids, input_data=input_proto, is_pred=True)
+     gt_img, gt_box = _make_box_annot(
+         gt, accum_id=accum_gt_ids, input_data=input_proto, is_pred=False)
+
+     accum_pred_ids += len(pred)
+     pred_images.append(pred_img)
+     pred_boxes += pred_box
+
+     accum_gt_ids += len(gt)
+     gt_images.append(gt_img)
+     gt_boxes += gt_box
+
+   pred_annots = {"images": pred_images, "annotations": pred_boxes, "categories": categories}
+   gt_annots = {"images": gt_images, "annotations": gt_boxes, "categories": categories}
+
+   return gt_annots, pred_annots
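These parsers reduce an `EvalMetrics` test set to plain arrays (classifier), per-image box arrays (detector), or COCO-style dicts (detector). As a quick illustration of the classifier output, the sketch below computes per-concept accuracy with numpy only; it assumes `eval_metrics` is an already-fetched `resources_pb2.EvalMetrics` proto with a populated `test_set`, and the 0.5 threshold is illustrative.

```python
import numpy as np

# eval_metrics: a resources_pb2.EvalMetrics proto with a populated test_set,
# obtained elsewhere (e.g. from a completed model evaluation).
y, y_preds, concept_ids, inputs = parse_eval_annotation_classifier(eval_metrics)

threshold = 0.5  # illustrative decision threshold
true_labels = (y >= threshold).astype(int)        # shape: (n_inputs, n_concepts)
pred_labels = (y_preds >= threshold).astype(int)

# Per-concept accuracy over the test set.
per_concept_acc = (pred_labels == true_labels).mean(axis=0)
for concept_id, acc in zip(concept_ids, per_concept_acc):
  print(f"{concept_id}: accuracy={acc:.3f} over {len(inputs)} inputs")
```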
clarifai/versions.py CHANGED
@@ -1,6 +1,6 @@
  import os

- CLIENT_VERSION = "10.2.1"
+ CLIENT_VERSION = "10.3.0"
  OS_VER = os.sys.platform
  PYTHON_VERSION = '.'.join(
      map(str, [os.sys.version_info.major, os.sys.version_info.minor, os.sys.version_info.micro]))
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: clarifai
- Version: 10.2.1
+ Version: 10.3.0
  Summary: Clarifai Python SDK
  Home-page: https://github.com/Clarifai/clarifai-python
  Author: Clarifai
@@ -20,18 +20,18 @@ Classifier: Operating System :: OS Independent
  Requires-Python: >=3.8
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: clarifai-grpc (~=10.2.1)
- Requires-Dist: numpy (>=1.22.0)
- Requires-Dist: tqdm (>=4.65.0)
- Requires-Dist: tritonclient (>=2.34.0)
- Requires-Dist: rich (>=13.4.2)
- Requires-Dist: PyYAML (>=6.0.1)
- Requires-Dist: schema (>=0.7.5)
- Requires-Dist: Pillow (>=9.5.0)
- Requires-Dist: inquirerpy (==0.3.4)
- Requires-Dist: tabulate (>=0.9.0)
+ Requires-Dist: clarifai-grpc ~=10.2.3
+ Requires-Dist: numpy >=1.22.0
+ Requires-Dist: tqdm >=4.65.0
+ Requires-Dist: tritonclient >=2.34.0
+ Requires-Dist: rich >=13.4.2
+ Requires-Dist: PyYAML >=6.0.1
+ Requires-Dist: schema >=0.7.5
+ Requires-Dist: Pillow >=9.5.0
+ Requires-Dist: inquirerpy ==0.3.4
+ Requires-Dist: tabulate >=0.9.0
  Provides-Extra: all
- Requires-Dist: pycocotools (==2.0.6) ; extra == 'all'
+ Requires-Dist: pycocotools ==2.0.6 ; extra == 'all'

  <h1 align="center">
  <a href="https://www.clarifai.com/"><img alt="Clarifai" title="Clarifai" src="https://upload.wikimedia.org/wikipedia/commons/b/bc/Clarifai_Logo_FC_Web.png"></a>