supervisely 6.73.238__py3-none-any.whl → 6.73.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. supervisely/annotation/annotation.py +2 -2
  2. supervisely/api/entity_annotation/tag_api.py +11 -4
  3. supervisely/nn/__init__.py +1 -0
  4. supervisely/nn/benchmark/__init__.py +14 -2
  5. supervisely/nn/benchmark/base_benchmark.py +84 -37
  6. supervisely/nn/benchmark/base_evaluator.py +120 -0
  7. supervisely/nn/benchmark/base_visualizer.py +265 -0
  8. supervisely/nn/benchmark/comparison/detection_visualization/text_templates.py +5 -5
  9. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/calibration_score.py +2 -2
  10. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/explore_predicttions.py +39 -16
  11. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/localization_accuracy.py +1 -1
  12. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/outcome_counts.py +4 -4
  13. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/overview.py +12 -11
  14. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/pr_curve.py +1 -1
  15. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/precision_recal_f1.py +6 -6
  16. supervisely/nn/benchmark/comparison/detection_visualization/vis_metrics/speedtest.py +3 -3
  17. supervisely/nn/benchmark/{instance_segmentation_benchmark.py → instance_segmentation/benchmark.py} +9 -3
  18. supervisely/nn/benchmark/instance_segmentation/evaluator.py +58 -0
  19. supervisely/nn/benchmark/{visualization/text_templates/instance_segmentation_text.py → instance_segmentation/text_templates.py} +53 -69
  20. supervisely/nn/benchmark/instance_segmentation/visualizer.py +18 -0
  21. supervisely/nn/benchmark/object_detection/__init__.py +0 -0
  22. supervisely/nn/benchmark/object_detection/base_vis_metric.py +51 -0
  23. supervisely/nn/benchmark/{object_detection_benchmark.py → object_detection/benchmark.py} +4 -2
  24. supervisely/nn/benchmark/object_detection/evaluation_params.yaml +2 -0
  25. supervisely/nn/benchmark/{evaluation/object_detection_evaluator.py → object_detection/evaluator.py} +67 -9
  26. supervisely/nn/benchmark/{evaluation/coco → object_detection}/metric_provider.py +13 -14
  27. supervisely/nn/benchmark/{visualization/text_templates/object_detection_text.py → object_detection/text_templates.py} +49 -41
  28. supervisely/nn/benchmark/object_detection/vis_metrics/__init__.py +48 -0
  29. supervisely/nn/benchmark/{visualization → object_detection}/vis_metrics/confidence_distribution.py +20 -24
  30. supervisely/nn/benchmark/object_detection/vis_metrics/confidence_score.py +119 -0
  31. supervisely/nn/benchmark/{visualization → object_detection}/vis_metrics/confusion_matrix.py +34 -22
  32. supervisely/nn/benchmark/object_detection/vis_metrics/explore_predictions.py +129 -0
  33. supervisely/nn/benchmark/{visualization → object_detection}/vis_metrics/f1_score_at_different_iou.py +21 -26
  34. supervisely/nn/benchmark/object_detection/vis_metrics/frequently_confused.py +137 -0
  35. supervisely/nn/benchmark/object_detection/vis_metrics/iou_distribution.py +106 -0
  36. supervisely/nn/benchmark/object_detection/vis_metrics/key_metrics.py +136 -0
  37. supervisely/nn/benchmark/{visualization → object_detection}/vis_metrics/model_predictions.py +53 -49
  38. supervisely/nn/benchmark/object_detection/vis_metrics/outcome_counts.py +188 -0
  39. supervisely/nn/benchmark/object_detection/vis_metrics/outcome_counts_per_class.py +191 -0
  40. supervisely/nn/benchmark/object_detection/vis_metrics/overview.py +116 -0
  41. supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve.py +106 -0
  42. supervisely/nn/benchmark/object_detection/vis_metrics/pr_curve_by_class.py +49 -0
  43. supervisely/nn/benchmark/object_detection/vis_metrics/precision.py +72 -0
  44. supervisely/nn/benchmark/object_detection/vis_metrics/precision_avg_per_class.py +59 -0
  45. supervisely/nn/benchmark/object_detection/vis_metrics/recall.py +71 -0
  46. supervisely/nn/benchmark/object_detection/vis_metrics/recall_vs_precision.py +56 -0
  47. supervisely/nn/benchmark/object_detection/vis_metrics/reliability_diagram.py +110 -0
  48. supervisely/nn/benchmark/object_detection/vis_metrics/speedtest.py +151 -0
  49. supervisely/nn/benchmark/object_detection/visualizer.py +697 -0
  50. supervisely/nn/benchmark/semantic_segmentation/__init__.py +9 -0
  51. supervisely/nn/benchmark/semantic_segmentation/base_vis_metric.py +55 -0
  52. supervisely/nn/benchmark/semantic_segmentation/benchmark.py +32 -0
  53. supervisely/nn/benchmark/semantic_segmentation/evaluation_params.yaml +0 -0
  54. supervisely/nn/benchmark/semantic_segmentation/evaluator.py +162 -0
  55. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +153 -0
  56. supervisely/nn/benchmark/semantic_segmentation/text_templates.py +130 -0
  57. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/__init__.py +0 -0
  58. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/acknowledgement.py +15 -0
  59. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/classwise_error_analysis.py +57 -0
  60. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/confusion_matrix.py +92 -0
  61. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/explore_predictions.py +84 -0
  62. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/frequently_confused.py +101 -0
  63. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/iou_eou.py +45 -0
  64. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/key_metrics.py +60 -0
  65. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/model_predictions.py +107 -0
  66. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/overview.py +112 -0
  67. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/renormalized_error_ou.py +48 -0
  68. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/speedtest.py +178 -0
  69. supervisely/nn/benchmark/semantic_segmentation/vis_metrics/vis_texts.py +21 -0
  70. supervisely/nn/benchmark/semantic_segmentation/visualizer.py +304 -0
  71. supervisely/nn/benchmark/utils/__init__.py +12 -0
  72. supervisely/nn/benchmark/utils/detection/__init__.py +2 -0
  73. supervisely/nn/benchmark/{evaluation/coco → utils/detection}/calculate_metrics.py +6 -4
  74. supervisely/nn/benchmark/utils/detection/metric_provider.py +533 -0
  75. supervisely/nn/benchmark/{coco_utils → utils/detection}/sly2coco.py +4 -4
  76. supervisely/nn/benchmark/{coco_utils/utils.py → utils/detection/utlis.py} +11 -0
  77. supervisely/nn/benchmark/utils/semantic_segmentation/__init__.py +0 -0
  78. supervisely/nn/benchmark/utils/semantic_segmentation/calculate_metrics.py +35 -0
  79. supervisely/nn/benchmark/utils/semantic_segmentation/evaluator.py +804 -0
  80. supervisely/nn/benchmark/utils/semantic_segmentation/loader.py +65 -0
  81. supervisely/nn/benchmark/utils/semantic_segmentation/utils.py +109 -0
  82. supervisely/nn/benchmark/visualization/evaluation_result.py +17 -3
  83. supervisely/nn/benchmark/visualization/vis_click_data.py +1 -1
  84. supervisely/nn/benchmark/visualization/widgets/__init__.py +3 -0
  85. supervisely/nn/benchmark/visualization/widgets/chart/chart.py +12 -4
  86. supervisely/nn/benchmark/visualization/widgets/gallery/gallery.py +35 -8
  87. supervisely/nn/benchmark/visualization/widgets/gallery/template.html +8 -4
  88. supervisely/nn/benchmark/visualization/widgets/markdown/markdown.py +1 -1
  89. supervisely/nn/benchmark/visualization/widgets/notification/notification.py +11 -7
  90. supervisely/nn/benchmark/visualization/widgets/radio_group/__init__.py +0 -0
  91. supervisely/nn/benchmark/visualization/widgets/radio_group/radio_group.py +34 -0
  92. supervisely/nn/benchmark/visualization/widgets/table/table.py +9 -3
  93. supervisely/nn/benchmark/visualization/widgets/widget.py +4 -0
  94. supervisely/project/project.py +18 -6
  95. {supervisely-6.73.238.dist-info → supervisely-6.73.239.dist-info}/METADATA +3 -1
  96. {supervisely-6.73.238.dist-info → supervisely-6.73.239.dist-info}/RECORD +103 -81
  97. supervisely/nn/benchmark/coco_utils/__init__.py +0 -2
  98. supervisely/nn/benchmark/evaluation/__init__.py +0 -3
  99. supervisely/nn/benchmark/evaluation/base_evaluator.py +0 -64
  100. supervisely/nn/benchmark/evaluation/coco/__init__.py +0 -2
  101. supervisely/nn/benchmark/evaluation/instance_segmentation_evaluator.py +0 -88
  102. supervisely/nn/benchmark/utils.py +0 -13
  103. supervisely/nn/benchmark/visualization/inference_speed/__init__.py +0 -19
  104. supervisely/nn/benchmark/visualization/inference_speed/speedtest_batch.py +0 -161
  105. supervisely/nn/benchmark/visualization/inference_speed/speedtest_intro.py +0 -28
  106. supervisely/nn/benchmark/visualization/inference_speed/speedtest_overview.py +0 -141
  107. supervisely/nn/benchmark/visualization/inference_speed/speedtest_real_time.py +0 -63
  108. supervisely/nn/benchmark/visualization/text_templates/inference_speed_text.py +0 -23
  109. supervisely/nn/benchmark/visualization/vis_metric_base.py +0 -337
  110. supervisely/nn/benchmark/visualization/vis_metrics/__init__.py +0 -67
  111. supervisely/nn/benchmark/visualization/vis_metrics/classwise_error_analysis.py +0 -55
  112. supervisely/nn/benchmark/visualization/vis_metrics/confidence_score.py +0 -93
  113. supervisely/nn/benchmark/visualization/vis_metrics/explorer_grid.py +0 -144
  114. supervisely/nn/benchmark/visualization/vis_metrics/frequently_confused.py +0 -115
  115. supervisely/nn/benchmark/visualization/vis_metrics/iou_distribution.py +0 -86
  116. supervisely/nn/benchmark/visualization/vis_metrics/outcome_counts.py +0 -119
  117. supervisely/nn/benchmark/visualization/vis_metrics/outcome_counts_per_class.py +0 -148
  118. supervisely/nn/benchmark/visualization/vis_metrics/overall_error_analysis.py +0 -109
  119. supervisely/nn/benchmark/visualization/vis_metrics/overview.py +0 -189
  120. supervisely/nn/benchmark/visualization/vis_metrics/percision_avg_per_class.py +0 -57
  121. supervisely/nn/benchmark/visualization/vis_metrics/pr_curve.py +0 -101
  122. supervisely/nn/benchmark/visualization/vis_metrics/pr_curve_by_class.py +0 -46
  123. supervisely/nn/benchmark/visualization/vis_metrics/precision.py +0 -56
  124. supervisely/nn/benchmark/visualization/vis_metrics/recall.py +0 -54
  125. supervisely/nn/benchmark/visualization/vis_metrics/recall_vs_precision.py +0 -57
  126. supervisely/nn/benchmark/visualization/vis_metrics/reliability_diagram.py +0 -88
  127. supervisely/nn/benchmark/visualization/vis_metrics/what_is.py +0 -23
  128. supervisely/nn/benchmark/visualization/vis_templates.py +0 -241
  129. supervisely/nn/benchmark/visualization/vis_widgets.py +0 -128
  130. supervisely/nn/benchmark/visualization/visualizer.py +0 -729
  131. /supervisely/nn/benchmark/{visualization/text_templates → instance_segmentation}/__init__.py +0 -0
  132. /supervisely/nn/benchmark/{evaluation/coco → instance_segmentation}/evaluation_params.yaml +0 -0
  133. /supervisely/nn/benchmark/{evaluation/coco → utils/detection}/metrics.py +0 -0
  134. {supervisely-6.73.238.dist-info → supervisely-6.73.239.dist-info}/LICENSE +0 -0
  135. {supervisely-6.73.238.dist-info → supervisely-6.73.239.dist-info}/WHEEL +0 -0
  136. {supervisely-6.73.238.dist-info → supervisely-6.73.239.dist-info}/entry_points.txt +0 -0
  137. {supervisely-6.73.238.dist-info → supervisely-6.73.239.dist-info}/top_level.txt +0 -0

supervisely/annotation/annotation.py

@@ -2177,7 +2177,7 @@ class Annotation:
 
             if np.any(mask):  # figure may be entirely covered by others
                 g = lbl.geometry
-                new_mask = Bitmap(data=mask)
+                new_mask = Bitmap(data=mask, extra_validation=False)
                 new_lbl = lbl.clone(geometry=new_mask, obj_class=dest_class)
                 new_labels.append(new_lbl)
         new_ann = self.clone(labels=new_labels)
@@ -2366,7 +2366,7 @@ class Annotation:
         new_labels = []
         for obj_class, white_mask in class_mask.items():
             mask = white_mask == 255
-            bitmap = Bitmap(data=mask)
+            bitmap = Bitmap(data=mask, extra_validation=False)
             new_labels.append(Label(geometry=bitmap, obj_class=obj_class))
         return self.clone(labels=new_labels)
 
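
Both changed call sites construct a Bitmap from a boolean mask that the Annotation method has just computed itself, so the new extra_validation=False argument presumably skips the redundant input checks on already-valid masks. A minimal sketch of the call, assuming the usual supervisely.geometry.bitmap import path; the mask below is illustrative, not from the package:

    import numpy as np
    from supervisely.geometry.bitmap import Bitmap

    mask = np.zeros((100, 100), dtype=bool)
    mask[20:40, 30:60] = True  # hypothetical figure mask

    bm_checked = Bitmap(data=mask)                        # default: validates the input mask
    bm_fast = Bitmap(data=mask, extra_validation=False)   # 6.73.239 call sites skip the extra check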

supervisely/api/entity_annotation/tag_api.py

@@ -1,12 +1,13 @@
 # coding: utf-8
 
-from typing import List
+from typing import List, Optional
 
 from supervisely._utils import batched
 from supervisely.api.module_api import ApiField, ModuleApi
 from supervisely.collection.key_indexed_collection import KeyIndexedCollection
 from supervisely.task.progress import tqdm_sly
 from supervisely.video_annotation.key_id_map import KeyIdMap
+from supervisely.task.progress import tqdm_sly
 
 
 class TagApi(ModuleApi):
@@ -241,6 +242,7 @@ class TagApi(ModuleApi):
         tags_list: List[dict],
         batch_size: int = 100,
         log_progress: bool = False,
+        progress: Optional[tqdm_sly] = None,
     ) -> List[dict]:
         """
         Add Tags to existing Annotation Figures.
@@ -255,6 +257,8 @@ class TagApi(ModuleApi):
         :type batch_size: int
         :param log_progress: If True, will display a progress bar.
         :type log_progress: bool
+        :param progress: Progress bar object to display progress.
+        :type progress: Optional[tqdm_sly]
         :return: List of tags infos as dictionaries.
         :rtype: List[dict]
 
@@ -313,9 +317,12 @@ class TagApi(ModuleApi):
         if len(tags_list) == 0:
             return []
 
+        if progress is not None:
+            log_progress = False
+
         result = []
         if log_progress:
-            ds_progress = tqdm_sly(
+            progress = tqdm_sly(
                 desc="Adding tags to figures",
                 total=len(tags_list),
             )
@@ -323,6 +330,6 @@ class TagApi(ModuleApi):
             data = {ApiField.PROJECT_ID: project_id, ApiField.TAGS: batch}
             response = self._api.post("figures.tags.bulk.add", data)
             result.extend(response.json())
-            if log_progress:
-                ds_progress.update(len(batch))
+            if progress is not None:
+                progress.update(len(batch))
         return result
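
The bulk figure-tagging method above (its def line is not part of the hunk; add_to_objects is assumed here) gains an optional progress argument: when the caller passes its own tqdm_sly instance, the internal log_progress bar is disabled and the supplied bar is updated per batch instead. A hedged usage sketch; the method name, tags_list payload, and IDs are assumptions for illustration:

    import supervisely as sly
    from supervisely.api.entity_annotation.tag_api import TagApi
    from supervisely.task.progress import tqdm_sly

    api = sly.Api.from_env()
    tag_api = TagApi(api)  # normally reached through the Api object; direct construction shown for brevity

    tags_list = [{"tagId": 101, "figureId": 202}]  # hypothetical payload for "figures.tags.bulk.add"
    progress = tqdm_sly(desc="Adding tags to figures", total=len(tags_list))

    tag_api.add_to_objects(     # assumed method name; the hunk only shows its body
        project_id=123,         # hypothetical project id
        tags_list=tags_list,
        batch_size=100,
        progress=progress,      # new in 6.73.239; overrides log_progress
    )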

supervisely/nn/__init__.py

@@ -1,4 +1,5 @@
 import supervisely.nn.artifacts as artifacts
+import supervisely.nn.benchmark as benchmark
 import supervisely.nn.inference as inference
 from supervisely.nn.artifacts.artifacts import BaseTrainArtifacts, TrainInfo
 from supervisely.nn.prediction_dto import (

supervisely/nn/benchmark/__init__.py

@@ -1,2 +1,14 @@
-from supervisely.nn.benchmark.object_detection_benchmark import ObjectDetectionBenchmark
-from supervisely.nn.benchmark.instance_segmentation_benchmark import InstanceSegmentationBenchmark
+from supervisely.nn.benchmark.instance_segmentation.benchmark import (
+    InstanceSegmentationBenchmark,
+)
+from supervisely.nn.benchmark.instance_segmentation.evaluator import (
+    InstanceSegmentationEvaluator,
+)
+from supervisely.nn.benchmark.object_detection.benchmark import ObjectDetectionBenchmark
+from supervisely.nn.benchmark.object_detection.evaluator import ObjectDetectionEvaluator
+from supervisely.nn.benchmark.semantic_segmentation.benchmark import (
+    SemanticSegmentationBenchmark,
+)
+from supervisely.nn.benchmark.semantic_segmentation.evaluator import (
+    SemanticSegmentationEvaluator,
+)
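
Together with the supervisely/nn/__init__.py change above, this reorganizes the benchmark code into per-task subpackages and re-exports a Benchmark and an Evaluator class for each task. A sketch of the import surface after 6.73.239, based only on the two __init__ hunks:

    # Re-exports from supervisely/nn/benchmark/__init__.py
    from supervisely.nn.benchmark import (
        InstanceSegmentationBenchmark,
        InstanceSegmentationEvaluator,
        ObjectDetectionBenchmark,
        ObjectDetectionEvaluator,
        SemanticSegmentationBenchmark,
        SemanticSegmentationEvaluator,
    )

    # supervisely/nn/__init__.py now imports the subpackage as well,
    # so the same classes are reachable through the top-level namespace:
    import supervisely as sly

    benchmark_cls = sly.nn.benchmark.ObjectDetectionBenchmark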

supervisely/nn/benchmark/base_benchmark.py

@@ -1,26 +1,31 @@
 import os
-import yaml
 from typing import Callable, List, Optional, Tuple, Union
 
 import numpy as np
 
 from supervisely._utils import is_development
 from supervisely.api.api import Api
+from supervisely.api.module_api import ApiField
 from supervisely.api.project_api import ProjectInfo
 from supervisely.app.widgets import SlyTqdm
 from supervisely.io import env, fs, json
 from supervisely.io.fs import get_directory_size
-from supervisely.nn.benchmark.evaluation import BaseEvaluator
-from supervisely.nn.benchmark.utils import WORKSPACE_DESCRIPTION, WORKSPACE_NAME
-from supervisely.nn.benchmark.visualization.visualizer import Visualizer
+from supervisely.nn.benchmark.base_evaluator import BaseEvaluator
 from supervisely.nn.inference import SessionJSON
 from supervisely.project.project import download_project
 from supervisely.project.project_meta import ProjectMeta
 from supervisely.sly_logger import logger
 from supervisely.task.progress import tqdm_sly
 
+WORKSPACE_NAME = "Model Benchmark: predictions and differences"
+WORKSPACE_DESCRIPTION = "Technical workspace for model benchmarking. Contains predictions and differences between ground truth and predictions."
+
 
 class BaseBenchmark:
+    visualizer_cls = None
+    EVALUATION_DIR_NAME = "evaluation"
+    SPEEDTEST_DIR_NAME = "speedtest"
+    VISUALIZATIONS_DIR_NAME = "visualizations"
 
     def __init__(
         self,
@@ -41,6 +46,20 @@ class BaseBenchmark:
         self.diff_project_info: ProjectInfo = None
         self.gt_dataset_ids = gt_dataset_ids
         self.gt_images_ids = gt_images_ids
+        self.gt_dataset_infos = None
+        if gt_dataset_ids is not None:
+            self.gt_dataset_infos = self.api.dataset.get_list(
+                self.gt_project_info.id,
+                filters=[
+                    {
+                        ApiField.FIELD: ApiField.ID,
+                        ApiField.OPERATOR: "in",
+                        ApiField.VALUE: gt_dataset_ids,
+                    }
+                ],
+                recursive=True,
+            )
+        self.num_items = self._get_total_items_for_progress()
         self.output_dir = output_dir
         self.team_id = env.team_id()
         self.evaluator: BaseEvaluator = None
@@ -53,6 +72,7 @@ class BaseBenchmark:
         self.vis_texts = None
         self.inference_speed_text = None
         self.train_info = None
+        self.evaluator_app_info = None
         self.evaluation_params = evaluation_params
         self._validate_evaluation_params()
 
@@ -75,9 +95,12 @@ class BaseBenchmark:
         batch_size: int = 16,
         cache_project_on_agent: bool = False,
     ):
-        self.session = self._init_model_session(model_session, inference_settings)
-        self._eval_inference_info = self._run_inference(
-            output_project_id, batch_size, cache_project_on_agent
+        self.run_inference(
+            model_session=model_session,
+            inference_settings=inference_settings,
+            output_project_id=output_project_id,
+            batch_size=batch_size,
+            cache_project_on_agent=cache_project_on_agent,
         )
         self.evaluate(self.dt_project_info.id)
         self._dump_eval_inference_info(self._eval_inference_info)
@@ -103,6 +126,11 @@ class BaseBenchmark:
     ):
         model_info = self._fetch_model_info(self.session)
         self.dt_project_info = self._get_or_create_dt_project(output_project_id, model_info)
+        logger.info(
+            f"""
+            Predictions project ID: {self.dt_project_info.id},
+            workspace ID: {self.dt_project_info.workspace_id}"""
+        )
         if self.gt_images_ids is None:
             iterator = self.session.inference_project_id_async(
                 self.gt_project_info.id,
@@ -118,9 +146,7 @@ class BaseBenchmark:
                 batch_size=batch_size,
             )
         output_project_id = self.dt_project_info.id
-        with self.pbar(
-            message="Evaluation: Running inference", total=self.gt_project_info.items_count
-        ) as p:
+        with self.pbar(message="Evaluation: Running inference", total=self.num_items) as p:
            for _ in iterator:
                p.update(1)
         inference_info = {
@@ -130,6 +156,14 @@ class BaseBenchmark:
            "batch_size": batch_size,
            **model_info,
        }
+        if self.train_info:
+            self.train_info.pop("train_images_ids", None)
+            inference_info["train_info"] = self.train_info
+        if self.evaluator_app_info:
+            self.evaluator_app_info.pop("settings", None)
+            inference_info["evaluator_app_info"] = self.evaluator_app_info
+        if self.gt_images_ids:
+            inference_info["val_images_cnt"] = len(self.gt_images_ids)
         self.dt_project_info = self.api.project.get_info_by_id(self.dt_project_info.id)
         logger.debug(
            "Inference is finished.",
@@ -151,10 +185,10 @@ class BaseBenchmark:
         eval_results_dir = self.get_eval_results_dir()
         self.evaluator = self._get_evaluator_class()(
             gt_project_path=gt_project_path,
-            dt_project_path=dt_project_path,
+            pred_project_path=dt_project_path,
             result_dir=eval_results_dir,
             progress=self.pbar,
-            items_count=self.dt_project_info.items_count,
+            items_count=self.num_items,
             classes_whitelist=self.classes_whitelist,
             evaluation_params=self.evaluation_params,
         )
@@ -244,19 +278,18 @@ class BaseBenchmark:
     def get_project_paths(self):
         base_dir = self.get_base_dir()
         gt_path = os.path.join(base_dir, "gt_project")
-        dt_path = os.path.join(base_dir, "dt_project")
+        dt_path = os.path.join(base_dir, "pred_project")
         return gt_path, dt_path
 
     def get_eval_results_dir(self) -> str:
-        dir = os.path.join(self.get_base_dir(), "evaluation")
-        os.makedirs(dir, exist_ok=True)
-        return dir
+        eval_dir = os.path.join(self.get_base_dir(), self.EVALUATION_DIR_NAME)
+        os.makedirs(eval_dir, exist_ok=True)
+        return eval_dir
 
     def get_speedtest_results_dir(self) -> str:
-        checkpoint_name = self._speedtest["model_info"]["checkpoint_name"]
-        dir = os.path.join(self.output_dir, "speedtest", checkpoint_name)
-        os.makedirs(dir, exist_ok=True)
-        return dir
+        speedtest_dir = os.path.join(self.get_base_dir(), self.SPEEDTEST_DIR_NAME)
+        os.makedirs(speedtest_dir, exist_ok=True)
+        return speedtest_dir
 
     def upload_eval_results(self, remote_dir: str):
         eval_dir = self.get_eval_results_dir()
@@ -279,7 +312,7 @@ class BaseBenchmark:
         )
 
     def get_layout_results_dir(self) -> str:
-        dir = os.path.join(self.get_base_dir(), "layout")
+        dir = os.path.join(self.get_base_dir(), self.VISUALIZATIONS_DIR_NAME)
         os.makedirs(dir, exist_ok=True)
         return dir
 
@@ -321,12 +354,9 @@ class BaseBenchmark:
     def _download_projects(self, save_images=False):
         gt_path, dt_path = self.get_project_paths()
         if not os.path.exists(gt_path):
-            total = (
-                self.gt_project_info.items_count
-                if self.gt_images_ids is None
-                else len(self.gt_images_ids)
-            )
-            with self.pbar(message="Evaluation: Downloading GT annotations", total=total) as p:
+            with self.pbar(
+                message="Evaluation: Downloading GT annotations", total=self.num_items
+            ) as p:
                 download_project(
                     self.api,
                     self.gt_project_info.id,
@@ -341,12 +371,9 @@ class BaseBenchmark:
         else:
             logger.info(f"Found GT annotations in {gt_path}")
         if not os.path.exists(dt_path):
-            total = (
-                self.gt_project_info.items_count
-                if self.gt_images_ids is None
-                else len(self.gt_images_ids)
-            )
-            with self.pbar(message="Evaluation: Downloading Pred annotations", total=total) as p:
+            with self.pbar(
+                message="Evaluation: Downloading Pred annotations", total=self.num_items
+            ) as p:
                 download_project(
                     self.api,
                     self.dt_project_info.id,
@@ -374,7 +401,7 @@ class BaseBenchmark:
                 "id": app_info["id"],
             }
         else:
-            logger.warn("session.task_id is not set. App info will not be fetched.")
+            logger.warning("session.task_id is not set. App info will not be fetched.")
             app_info = None
         model_info = {
             **deploy_info,
@@ -450,8 +477,16 @@ class BaseBenchmark:
         if dt_project_id is not None:
             self.dt_project_info = self.api.project.get_info_by_id(dt_project_id)
 
-        vis = Visualizer(self)
-        vis.visualize()
+        if self.visualizer_cls is None:
+            raise RuntimeError(
+                f"Visualizer class is not defined in {self.__class__.__name__}. "
+                "It should be defined in the subclass of BaseBenchmark (e.g. ObjectDetectionBenchmark)."
+            )
+        eval_result = self.evaluator.get_eval_result()
+        vis = self.visualizer_cls(self.api, [eval_result], self.get_layout_results_dir(), self.pbar)  # pylint: disable=not-callable
+        with self.pbar(message="Visualizations: Rendering layout", total=1) as p:
+            vis.visualize()
+            p.update(1)
 
     def _get_or_create_diff_project(self) -> Tuple[ProjectInfo, bool]:
 
@@ -499,7 +534,7 @@ class BaseBenchmark:
         layout_dir = self.get_layout_results_dir()
         assert not fs.dir_empty(
             layout_dir
-        ), f"The layout dir {layout_dir!r} is empty. You should run evaluation before uploading results."
+        ), f"The layout dir {layout_dir!r} is empty. You should run visualizations before uploading results."
 
         # self.api.file.remove_dir(self.team_id, dest_dir, silent=True)
 
@@ -555,9 +590,21 @@ class BaseBenchmark:
             if not pred_meta.obj_classes.has_key(obj_cls.name):
                 pred_meta = pred_meta.add_obj_class(obj_cls)
                 chagned = True
+        for tag_meta in gt_meta.tag_metas:
+            if not pred_meta.tag_metas.has_key(tag_meta.name):
+                pred_meta = pred_meta.add_tag_meta(tag_meta)
+                chagned = True
         if chagned:
             self.api.project.update_meta(pred_project_id, pred_meta.to_json())
 
     def _validate_evaluation_params(self):
         if self.evaluation_params:
             self._get_evaluator_class().validate_evaluation_params(self.evaluation_params)
+
+    def _get_total_items_for_progress(self):
+        if self.gt_images_ids is not None:
+            return len(self.gt_images_ids)
+        elif self.gt_dataset_ids is not None:
+            return sum(ds.items_count for ds in self.gt_dataset_infos)
+        else:
+            return self.gt_project_info.items_count
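
The net effect of the base_benchmark.py hunks: workspace constants and directory names move onto the class, item totals are computed once in _get_total_items_for_progress(), and visualize() now requires the subclass to define visualizer_cls and builds the visualizer from evaluator.get_eval_result(). A hedged sketch of driving a task-specific benchmark after this change; the keyword names and the run_evaluation entry point are inferred from the hunks and may differ from the exact public signatures:

    import supervisely as sly
    from supervisely.nn.benchmark import ObjectDetectionBenchmark

    api = sly.Api.from_env()

    bench = ObjectDetectionBenchmark(
        api,
        gt_project_id=111,          # hypothetical IDs
        gt_dataset_ids=[222, 333],  # optional; used to resolve dataset infos and progress totals
        output_dir="./benchmark",
    )

    # The hunk at -75,9 wraps inference plus evaluation; evaluate() passes
    # pred_project_path/result_dir to the evaluator returned by
    # _get_evaluator_class(), and visualize() raises RuntimeError unless
    # the subclass defines visualizer_cls.
    bench.run_evaluation(model_session=444)  # hypothetical deployed-model task id
    bench.visualize()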

supervisely/nn/benchmark/base_evaluator.py (new file)

@@ -0,0 +1,120 @@
+import os
+import pickle
+from typing import Dict, List, Optional, Union
+
+import yaml
+
+from supervisely.app.widgets import SlyTqdm
+from supervisely.task.progress import tqdm_sly
+
+
+class BaseEvalResult:
+    def __init__(self, directory: str):
+        self.directory = directory
+        self.inference_info: Dict = None
+        self.speedtest_info: Dict = None
+        self.eval_data: Dict = None
+        self.mp = None
+
+        self._read_eval_data()
+
+    @property
+    def cv_task(self):
+        return self.inference_info.get("task_type")
+
+    @property
+    def name(self) -> str:
+        model_name = self.inference_info.get("model_name", self.directory)
+        return self.inference_info.get("deploy_params", {}).get("checkpoint_name", model_name)
+
+    @property
+    def gt_project_id(self) -> int:
+        return self.inference_info.get("gt_project_id")
+
+    @property
+    def gt_dataset_ids(self) -> List[int]:
+        return self.inference_info.get("gt_dataset_ids", None)
+
+    @property
+    def dt_project_id(self):
+        return self.inference_info.get("dt_project_id")
+
+    @property
+    def pred_project_id(self):
+        return self.dt_project_id
+
+    @property
+    def train_info(self):
+        return self.inference_info.get("train_info", None)  # TODO: check
+
+    @property
+    def evaluator_app_info(self):
+        return self.inference_info.get("evaluator_app_info", None)  # TODO: check
+
+    @property
+    def val_images_cnt(self):
+        return self.inference_info.get("val_images_cnt", None)  # TODO: check
+
+    @property
+    def classes_whitelist(self):
+        return self.inference_info.get("inference_settings", {}).get("classes", [])  # TODO: check
+
+    def _read_eval_data(self):
+        raise NotImplementedError()
+
+
+class BaseEvaluator:
+    EVALUATION_PARAMS_YAML_PATH: str = None
+    eval_result_cls = BaseEvalResult
+
+    def __init__(
+        self,
+        gt_project_path: str,
+        pred_project_path: str,
+        result_dir: str = "./evaluation",
+        progress: Optional[SlyTqdm] = None,
+        items_count: Optional[int] = None,  # TODO: is it needed?
+        classes_whitelist: Optional[List[str]] = None,
+        evaluation_params: Optional[dict] = None,
+    ):
+        self.gt_project_path = gt_project_path
+        self.pred_project_path = pred_project_path
+        self.result_dir = result_dir
+        self.total_items = items_count
+        self.pbar = progress or tqdm_sly
+        os.makedirs(result_dir, exist_ok=True)
+        self.classes_whitelist = classes_whitelist
+
+        if evaluation_params is None:
+            evaluation_params = self._get_default_evaluation_params()
+        self.evaluation_params = evaluation_params
+        if self.evaluation_params:
+            self.validate_evaluation_params(self.evaluation_params)
+
+    def evaluate(self):
+        raise NotImplementedError()
+
+    @classmethod
+    def load_yaml_evaluation_params(cls) -> Union[str, None]:
+        if cls.EVALUATION_PARAMS_YAML_PATH is None:
+            return None
+        with open(cls.EVALUATION_PARAMS_YAML_PATH, "r") as f:
+            return f.read()
+
+    @classmethod
+    def validate_evaluation_params(cls, evaluation_params: dict) -> None:
+        pass
+
+    @classmethod
+    def _get_default_evaluation_params(cls) -> dict:
+        if cls.EVALUATION_PARAMS_YAML_PATH is None:
+            return {}
+        else:
+            return yaml.safe_load(cls.load_yaml_evaluation_params())
+
+    def _dump_pickle(self, data, file_path):
+        with open(file_path, "wb") as f:
+            pickle.dump(data, f)
+
+    def get_eval_result(self) -> BaseEvalResult:
+        return self.eval_result_cls(self.result_dir)
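
BaseEvaluator and BaseEvalResult replace the removed evaluation/base_evaluator.py (file 99 in the list above): get_eval_result() instantiates eval_result_cls on result_dir, and the result class is expected to populate its fields in _read_eval_data(). A minimal sketch of a custom pair built on this contract; the pickle file name and metric content are illustrative, not taken from the package:

    import os
    import pickle

    from supervisely.nn.benchmark.base_evaluator import BaseEvalResult, BaseEvaluator


    class MyEvalResult(BaseEvalResult):
        def _read_eval_data(self):
            # Illustrative layout: load whatever the evaluator dumped earlier.
            path = os.path.join(self.directory, "eval_data.pkl")
            if os.path.exists(path):
                with open(path, "rb") as f:
                    self.eval_data = pickle.load(f)


    class MyEvaluator(BaseEvaluator):
        eval_result_cls = MyEvalResult

        def evaluate(self):
            # Compare self.gt_project_path with self.pred_project_path here and
            # persist the metrics; _dump_pickle() is provided by the base class.
            metrics = {"dummy_metric": 1.0}
            self._dump_pickle(metrics, os.path.join(self.result_dir, "eval_data.pkl"))


    # evaluator = MyEvaluator(gt_project_path="...", pred_project_path="...", result_dir="./evaluation")
    # evaluator.evaluate()
    # result = evaluator.get_eval_result()  # -> MyEvalResult reading ./evaluation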