clarifai 11.3.0rc2__py3-none-any.whl → 11.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (300)
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/__main__.py +1 -1
  3. clarifai/cli/base.py +144 -136
  4. clarifai/cli/compute_cluster.py +45 -31
  5. clarifai/cli/deployment.py +93 -76
  6. clarifai/cli/model.py +578 -180
  7. clarifai/cli/nodepool.py +100 -82
  8. clarifai/client/__init__.py +12 -2
  9. clarifai/client/app.py +973 -911
  10. clarifai/client/auth/helper.py +345 -342
  11. clarifai/client/auth/register.py +7 -7
  12. clarifai/client/auth/stub.py +107 -106
  13. clarifai/client/base.py +185 -178
  14. clarifai/client/compute_cluster.py +214 -180
  15. clarifai/client/dataset.py +793 -698
  16. clarifai/client/deployment.py +55 -50
  17. clarifai/client/input.py +1223 -1088
  18. clarifai/client/lister.py +47 -45
  19. clarifai/client/model.py +1939 -1717
  20. clarifai/client/model_client.py +525 -502
  21. clarifai/client/module.py +82 -73
  22. clarifai/client/nodepool.py +358 -213
  23. clarifai/client/runner.py +58 -0
  24. clarifai/client/search.py +342 -309
  25. clarifai/client/user.py +419 -414
  26. clarifai/client/workflow.py +294 -274
  27. clarifai/constants/dataset.py +11 -17
  28. clarifai/constants/model.py +8 -2
  29. clarifai/datasets/export/inputs_annotations.py +233 -217
  30. clarifai/datasets/upload/base.py +63 -51
  31. clarifai/datasets/upload/features.py +43 -38
  32. clarifai/datasets/upload/image.py +237 -207
  33. clarifai/datasets/upload/loaders/coco_captions.py +34 -32
  34. clarifai/datasets/upload/loaders/coco_detection.py +72 -65
  35. clarifai/datasets/upload/loaders/imagenet_classification.py +57 -53
  36. clarifai/datasets/upload/loaders/xview_detection.py +274 -132
  37. clarifai/datasets/upload/multimodal.py +55 -46
  38. clarifai/datasets/upload/text.py +55 -47
  39. clarifai/datasets/upload/utils.py +250 -234
  40. clarifai/errors.py +51 -50
  41. clarifai/models/api.py +260 -238
  42. clarifai/modules/css.py +50 -50
  43. clarifai/modules/pages.py +33 -33
  44. clarifai/rag/rag.py +312 -288
  45. clarifai/rag/utils.py +91 -84
  46. clarifai/runners/models/model_builder.py +906 -802
  47. clarifai/runners/models/model_class.py +370 -331
  48. clarifai/runners/models/model_run_locally.py +459 -419
  49. clarifai/runners/models/model_runner.py +170 -162
  50. clarifai/runners/models/model_servicer.py +78 -70
  51. clarifai/runners/server.py +111 -101
  52. clarifai/runners/utils/code_script.py +225 -187
  53. clarifai/runners/utils/const.py +4 -1
  54. clarifai/runners/utils/data_types/__init__.py +12 -0
  55. clarifai/runners/utils/data_types/data_types.py +598 -0
  56. clarifai/runners/utils/data_utils.py +387 -440
  57. clarifai/runners/utils/loader.py +247 -227
  58. clarifai/runners/utils/method_signatures.py +411 -386
  59. clarifai/runners/utils/openai_convertor.py +108 -109
  60. clarifai/runners/utils/serializers.py +175 -179
  61. clarifai/runners/utils/url_fetcher.py +35 -35
  62. clarifai/schema/search.py +56 -63
  63. clarifai/urls/helper.py +125 -102
  64. clarifai/utils/cli.py +129 -123
  65. clarifai/utils/config.py +127 -87
  66. clarifai/utils/constants.py +49 -0
  67. clarifai/utils/evaluation/helpers.py +503 -466
  68. clarifai/utils/evaluation/main.py +431 -393
  69. clarifai/utils/evaluation/testset_annotation_parser.py +154 -144
  70. clarifai/utils/logging.py +324 -306
  71. clarifai/utils/misc.py +60 -56
  72. clarifai/utils/model_train.py +165 -146
  73. clarifai/utils/protobuf.py +126 -103
  74. clarifai/versions.py +3 -1
  75. clarifai/workflows/export.py +48 -50
  76. clarifai/workflows/utils.py +39 -36
  77. clarifai/workflows/validate.py +55 -43
  78. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/METADATA +16 -6
  79. clarifai-11.4.0.dist-info/RECORD +109 -0
  80. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/WHEEL +1 -1
  81. clarifai/__pycache__/__init__.cpython-310.pyc +0 -0
  82. clarifai/__pycache__/__init__.cpython-311.pyc +0 -0
  83. clarifai/__pycache__/__init__.cpython-39.pyc +0 -0
  84. clarifai/__pycache__/errors.cpython-310.pyc +0 -0
  85. clarifai/__pycache__/errors.cpython-311.pyc +0 -0
  86. clarifai/__pycache__/versions.cpython-310.pyc +0 -0
  87. clarifai/__pycache__/versions.cpython-311.pyc +0 -0
  88. clarifai/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  89. clarifai/cli/__pycache__/__init__.cpython-311.pyc +0 -0
  90. clarifai/cli/__pycache__/base.cpython-310.pyc +0 -0
  91. clarifai/cli/__pycache__/base.cpython-311.pyc +0 -0
  92. clarifai/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  93. clarifai/cli/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  94. clarifai/cli/__pycache__/compute_cluster.cpython-311.pyc +0 -0
  95. clarifai/cli/__pycache__/deployment.cpython-310.pyc +0 -0
  96. clarifai/cli/__pycache__/deployment.cpython-311.pyc +0 -0
  97. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  98. clarifai/cli/__pycache__/model.cpython-311.pyc +0 -0
  99. clarifai/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  100. clarifai/cli/__pycache__/nodepool.cpython-310.pyc +0 -0
  101. clarifai/cli/__pycache__/nodepool.cpython-311.pyc +0 -0
  102. clarifai/client/__pycache__/__init__.cpython-310.pyc +0 -0
  103. clarifai/client/__pycache__/__init__.cpython-311.pyc +0 -0
  104. clarifai/client/__pycache__/__init__.cpython-39.pyc +0 -0
  105. clarifai/client/__pycache__/app.cpython-310.pyc +0 -0
  106. clarifai/client/__pycache__/app.cpython-311.pyc +0 -0
  107. clarifai/client/__pycache__/app.cpython-39.pyc +0 -0
  108. clarifai/client/__pycache__/base.cpython-310.pyc +0 -0
  109. clarifai/client/__pycache__/base.cpython-311.pyc +0 -0
  110. clarifai/client/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  111. clarifai/client/__pycache__/compute_cluster.cpython-311.pyc +0 -0
  112. clarifai/client/__pycache__/dataset.cpython-310.pyc +0 -0
  113. clarifai/client/__pycache__/dataset.cpython-311.pyc +0 -0
  114. clarifai/client/__pycache__/deployment.cpython-310.pyc +0 -0
  115. clarifai/client/__pycache__/deployment.cpython-311.pyc +0 -0
  116. clarifai/client/__pycache__/input.cpython-310.pyc +0 -0
  117. clarifai/client/__pycache__/input.cpython-311.pyc +0 -0
  118. clarifai/client/__pycache__/lister.cpython-310.pyc +0 -0
  119. clarifai/client/__pycache__/lister.cpython-311.pyc +0 -0
  120. clarifai/client/__pycache__/model.cpython-310.pyc +0 -0
  121. clarifai/client/__pycache__/model.cpython-311.pyc +0 -0
  122. clarifai/client/__pycache__/module.cpython-310.pyc +0 -0
  123. clarifai/client/__pycache__/module.cpython-311.pyc +0 -0
  124. clarifai/client/__pycache__/nodepool.cpython-310.pyc +0 -0
  125. clarifai/client/__pycache__/nodepool.cpython-311.pyc +0 -0
  126. clarifai/client/__pycache__/search.cpython-310.pyc +0 -0
  127. clarifai/client/__pycache__/search.cpython-311.pyc +0 -0
  128. clarifai/client/__pycache__/user.cpython-310.pyc +0 -0
  129. clarifai/client/__pycache__/user.cpython-311.pyc +0 -0
  130. clarifai/client/__pycache__/workflow.cpython-310.pyc +0 -0
  131. clarifai/client/__pycache__/workflow.cpython-311.pyc +0 -0
  132. clarifai/client/auth/__pycache__/__init__.cpython-310.pyc +0 -0
  133. clarifai/client/auth/__pycache__/__init__.cpython-311.pyc +0 -0
  134. clarifai/client/auth/__pycache__/helper.cpython-310.pyc +0 -0
  135. clarifai/client/auth/__pycache__/helper.cpython-311.pyc +0 -0
  136. clarifai/client/auth/__pycache__/register.cpython-310.pyc +0 -0
  137. clarifai/client/auth/__pycache__/register.cpython-311.pyc +0 -0
  138. clarifai/client/auth/__pycache__/stub.cpython-310.pyc +0 -0
  139. clarifai/client/auth/__pycache__/stub.cpython-311.pyc +0 -0
  140. clarifai/client/cli/__init__.py +0 -0
  141. clarifai/client/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  142. clarifai/client/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  143. clarifai/client/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  144. clarifai/client/cli/base_cli.py +0 -88
  145. clarifai/client/cli/model_cli.py +0 -29
  146. clarifai/constants/__pycache__/base.cpython-310.pyc +0 -0
  147. clarifai/constants/__pycache__/base.cpython-311.pyc +0 -0
  148. clarifai/constants/__pycache__/dataset.cpython-310.pyc +0 -0
  149. clarifai/constants/__pycache__/dataset.cpython-311.pyc +0 -0
  150. clarifai/constants/__pycache__/input.cpython-310.pyc +0 -0
  151. clarifai/constants/__pycache__/input.cpython-311.pyc +0 -0
  152. clarifai/constants/__pycache__/model.cpython-310.pyc +0 -0
  153. clarifai/constants/__pycache__/model.cpython-311.pyc +0 -0
  154. clarifai/constants/__pycache__/rag.cpython-310.pyc +0 -0
  155. clarifai/constants/__pycache__/rag.cpython-311.pyc +0 -0
  156. clarifai/constants/__pycache__/search.cpython-310.pyc +0 -0
  157. clarifai/constants/__pycache__/search.cpython-311.pyc +0 -0
  158. clarifai/constants/__pycache__/workflow.cpython-310.pyc +0 -0
  159. clarifai/constants/__pycache__/workflow.cpython-311.pyc +0 -0
  160. clarifai/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  161. clarifai/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
  162. clarifai/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  163. clarifai/datasets/export/__pycache__/__init__.cpython-310.pyc +0 -0
  164. clarifai/datasets/export/__pycache__/__init__.cpython-311.pyc +0 -0
  165. clarifai/datasets/export/__pycache__/__init__.cpython-39.pyc +0 -0
  166. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-310.pyc +0 -0
  167. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-311.pyc +0 -0
  168. clarifai/datasets/upload/__pycache__/__init__.cpython-310.pyc +0 -0
  169. clarifai/datasets/upload/__pycache__/__init__.cpython-311.pyc +0 -0
  170. clarifai/datasets/upload/__pycache__/__init__.cpython-39.pyc +0 -0
  171. clarifai/datasets/upload/__pycache__/base.cpython-310.pyc +0 -0
  172. clarifai/datasets/upload/__pycache__/base.cpython-311.pyc +0 -0
  173. clarifai/datasets/upload/__pycache__/features.cpython-310.pyc +0 -0
  174. clarifai/datasets/upload/__pycache__/features.cpython-311.pyc +0 -0
  175. clarifai/datasets/upload/__pycache__/image.cpython-310.pyc +0 -0
  176. clarifai/datasets/upload/__pycache__/image.cpython-311.pyc +0 -0
  177. clarifai/datasets/upload/__pycache__/multimodal.cpython-310.pyc +0 -0
  178. clarifai/datasets/upload/__pycache__/multimodal.cpython-311.pyc +0 -0
  179. clarifai/datasets/upload/__pycache__/text.cpython-310.pyc +0 -0
  180. clarifai/datasets/upload/__pycache__/text.cpython-311.pyc +0 -0
  181. clarifai/datasets/upload/__pycache__/utils.cpython-310.pyc +0 -0
  182. clarifai/datasets/upload/__pycache__/utils.cpython-311.pyc +0 -0
  183. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-311.pyc +0 -0
  184. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-39.pyc +0 -0
  185. clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-311.pyc +0 -0
  186. clarifai/datasets/upload/loaders/__pycache__/imagenet_classification.cpython-311.pyc +0 -0
  187. clarifai/models/__pycache__/__init__.cpython-39.pyc +0 -0
  188. clarifai/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  189. clarifai/rag/__pycache__/__init__.cpython-310.pyc +0 -0
  190. clarifai/rag/__pycache__/__init__.cpython-311.pyc +0 -0
  191. clarifai/rag/__pycache__/__init__.cpython-39.pyc +0 -0
  192. clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
  193. clarifai/rag/__pycache__/rag.cpython-311.pyc +0 -0
  194. clarifai/rag/__pycache__/rag.cpython-39.pyc +0 -0
  195. clarifai/rag/__pycache__/utils.cpython-310.pyc +0 -0
  196. clarifai/rag/__pycache__/utils.cpython-311.pyc +0 -0
  197. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  198. clarifai/runners/__pycache__/__init__.cpython-311.pyc +0 -0
  199. clarifai/runners/__pycache__/__init__.cpython-39.pyc +0 -0
  200. clarifai/runners/dockerfile_template/Dockerfile.cpu.template +0 -31
  201. clarifai/runners/dockerfile_template/Dockerfile.cuda.template +0 -42
  202. clarifai/runners/dockerfile_template/Dockerfile.nim +0 -71
  203. clarifai/runners/models/__pycache__/__init__.cpython-310.pyc +0 -0
  204. clarifai/runners/models/__pycache__/__init__.cpython-311.pyc +0 -0
  205. clarifai/runners/models/__pycache__/__init__.cpython-39.pyc +0 -0
  206. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  207. clarifai/runners/models/__pycache__/base_typed_model.cpython-311.pyc +0 -0
  208. clarifai/runners/models/__pycache__/base_typed_model.cpython-39.pyc +0 -0
  209. clarifai/runners/models/__pycache__/model_builder.cpython-311.pyc +0 -0
  210. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  211. clarifai/runners/models/__pycache__/model_class.cpython-311.pyc +0 -0
  212. clarifai/runners/models/__pycache__/model_run_locally.cpython-310-pytest-7.1.2.pyc +0 -0
  213. clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
  214. clarifai/runners/models/__pycache__/model_run_locally.cpython-311.pyc +0 -0
  215. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  216. clarifai/runners/models/__pycache__/model_runner.cpython-311.pyc +0 -0
  217. clarifai/runners/models/__pycache__/model_upload.cpython-310.pyc +0 -0
  218. clarifai/runners/models/base_typed_model.py +0 -238
  219. clarifai/runners/models/model_class_refract.py +0 -80
  220. clarifai/runners/models/model_upload.py +0 -607
  221. clarifai/runners/models/temp.py +0 -25
  222. clarifai/runners/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  223. clarifai/runners/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  224. clarifai/runners/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  225. clarifai/runners/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  226. clarifai/runners/utils/__pycache__/buffered_stream.cpython-310.pyc +0 -0
  227. clarifai/runners/utils/__pycache__/buffered_stream.cpython-38.pyc +0 -0
  228. clarifai/runners/utils/__pycache__/buffered_stream.cpython-39.pyc +0 -0
  229. clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
  230. clarifai/runners/utils/__pycache__/const.cpython-311.pyc +0 -0
  231. clarifai/runners/utils/__pycache__/constants.cpython-310.pyc +0 -0
  232. clarifai/runners/utils/__pycache__/constants.cpython-38.pyc +0 -0
  233. clarifai/runners/utils/__pycache__/constants.cpython-39.pyc +0 -0
  234. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  235. clarifai/runners/utils/__pycache__/data_handler.cpython-311.pyc +0 -0
  236. clarifai/runners/utils/__pycache__/data_handler.cpython-38.pyc +0 -0
  237. clarifai/runners/utils/__pycache__/data_handler.cpython-39.pyc +0 -0
  238. clarifai/runners/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
  239. clarifai/runners/utils/__pycache__/data_utils.cpython-311.pyc +0 -0
  240. clarifai/runners/utils/__pycache__/data_utils.cpython-38.pyc +0 -0
  241. clarifai/runners/utils/__pycache__/data_utils.cpython-39.pyc +0 -0
  242. clarifai/runners/utils/__pycache__/grpc_server.cpython-310.pyc +0 -0
  243. clarifai/runners/utils/__pycache__/grpc_server.cpython-38.pyc +0 -0
  244. clarifai/runners/utils/__pycache__/grpc_server.cpython-39.pyc +0 -0
  245. clarifai/runners/utils/__pycache__/health.cpython-310.pyc +0 -0
  246. clarifai/runners/utils/__pycache__/health.cpython-38.pyc +0 -0
  247. clarifai/runners/utils/__pycache__/health.cpython-39.pyc +0 -0
  248. clarifai/runners/utils/__pycache__/loader.cpython-310.pyc +0 -0
  249. clarifai/runners/utils/__pycache__/loader.cpython-311.pyc +0 -0
  250. clarifai/runners/utils/__pycache__/logging.cpython-310.pyc +0 -0
  251. clarifai/runners/utils/__pycache__/logging.cpython-38.pyc +0 -0
  252. clarifai/runners/utils/__pycache__/logging.cpython-39.pyc +0 -0
  253. clarifai/runners/utils/__pycache__/stream_source.cpython-310.pyc +0 -0
  254. clarifai/runners/utils/__pycache__/stream_source.cpython-39.pyc +0 -0
  255. clarifai/runners/utils/__pycache__/url_fetcher.cpython-310.pyc +0 -0
  256. clarifai/runners/utils/__pycache__/url_fetcher.cpython-311.pyc +0 -0
  257. clarifai/runners/utils/__pycache__/url_fetcher.cpython-38.pyc +0 -0
  258. clarifai/runners/utils/__pycache__/url_fetcher.cpython-39.pyc +0 -0
  259. clarifai/runners/utils/data_handler.py +0 -231
  260. clarifai/runners/utils/data_handler_refract.py +0 -213
  261. clarifai/runners/utils/data_types.py +0 -469
  262. clarifai/runners/utils/logger.py +0 -0
  263. clarifai/runners/utils/openai_format.py +0 -87
  264. clarifai/schema/__pycache__/search.cpython-310.pyc +0 -0
  265. clarifai/schema/__pycache__/search.cpython-311.pyc +0 -0
  266. clarifai/urls/__pycache__/helper.cpython-310.pyc +0 -0
  267. clarifai/urls/__pycache__/helper.cpython-311.pyc +0 -0
  268. clarifai/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  269. clarifai/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  270. clarifai/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  271. clarifai/utils/__pycache__/cli.cpython-310.pyc +0 -0
  272. clarifai/utils/__pycache__/cli.cpython-311.pyc +0 -0
  273. clarifai/utils/__pycache__/config.cpython-311.pyc +0 -0
  274. clarifai/utils/__pycache__/constants.cpython-310.pyc +0 -0
  275. clarifai/utils/__pycache__/constants.cpython-311.pyc +0 -0
  276. clarifai/utils/__pycache__/logging.cpython-310.pyc +0 -0
  277. clarifai/utils/__pycache__/logging.cpython-311.pyc +0 -0
  278. clarifai/utils/__pycache__/misc.cpython-310.pyc +0 -0
  279. clarifai/utils/__pycache__/misc.cpython-311.pyc +0 -0
  280. clarifai/utils/__pycache__/model_train.cpython-310.pyc +0 -0
  281. clarifai/utils/__pycache__/model_train.cpython-311.pyc +0 -0
  282. clarifai/utils/__pycache__/protobuf.cpython-311.pyc +0 -0
  283. clarifai/utils/evaluation/__pycache__/__init__.cpython-311.pyc +0 -0
  284. clarifai/utils/evaluation/__pycache__/__init__.cpython-39.pyc +0 -0
  285. clarifai/utils/evaluation/__pycache__/helpers.cpython-311.pyc +0 -0
  286. clarifai/utils/evaluation/__pycache__/main.cpython-311.pyc +0 -0
  287. clarifai/utils/evaluation/__pycache__/main.cpython-39.pyc +0 -0
  288. clarifai/workflows/__pycache__/__init__.cpython-310.pyc +0 -0
  289. clarifai/workflows/__pycache__/__init__.cpython-311.pyc +0 -0
  290. clarifai/workflows/__pycache__/__init__.cpython-39.pyc +0 -0
  291. clarifai/workflows/__pycache__/export.cpython-310.pyc +0 -0
  292. clarifai/workflows/__pycache__/export.cpython-311.pyc +0 -0
  293. clarifai/workflows/__pycache__/utils.cpython-310.pyc +0 -0
  294. clarifai/workflows/__pycache__/utils.cpython-311.pyc +0 -0
  295. clarifai/workflows/__pycache__/validate.cpython-310.pyc +0 -0
  296. clarifai/workflows/__pycache__/validate.cpython-311.pyc +0 -0
  297. clarifai-11.3.0rc2.dist-info/RECORD +0 -322
  298. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/entry_points.txt +0 -0
  299. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info/licenses}/LICENSE +0 -0
  300. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/top_level.txt +0 -0
clarifai/utils/evaluation/main.py
@@ -5,422 +5,460 @@ from typing import List, Tuple, Union
  from clarifai.client.dataset import Dataset
  from clarifai.client.model import Model

- from .helpers import (MACRO_AVG, EvalType, _BaseEvalResultHandler, get_eval_type,
-                       make_handler_by_type)
+ from .helpers import (
+     MACRO_AVG,
+     EvalType,
+     _BaseEvalResultHandler,
+     get_eval_type,
+     make_handler_by_type,
+ )

  try:
-   import seaborn as sns
+     import seaborn as sns
  except ImportError:
-   raise ImportError("Can not import seaborn. Please run `pip install seaborn` to install it")
+     raise ImportError("Can not import seaborn. Please run `pip install seaborn` to install it")

  try:
-   import matplotlib.pyplot as plt
+     import matplotlib.pyplot as plt
  except ImportError:
-   raise ImportError("Can not import matplotlib. Please run `pip install matplotlib` to install it")
+     raise ImportError(
+         "Can not import matplotlib. Please run `pip install matplotlib` to install it"
+     )

  try:
-   import pandas as pd
+     import pandas as pd
  except ImportError:
-   raise ImportError("Can not import pandas. Please run `pip install pandas` to install it")
+     raise ImportError("Can not import pandas. Please run `pip install pandas` to install it")

  try:
-   from loguru import logger
+     from loguru import logger
  except ImportError:
-   from ..logging import logger
+     from ..logging import logger

  __all__ = ['EvalResultCompare']


  class CompareMode(Enum):
-   MANY_MODELS_TO_ONE_DATA = 0
-   ONE_MODEL_TO_MANY_DATA = 1
+     MANY_MODELS_TO_ONE_DATA = 0
+     ONE_MODEL_TO_MANY_DATA = 1


  class EvalResultCompare:
[... The rest of this hunk (old lines 40-426, new lines 47-464) applies the same mechanical reformatting to the body of the EvalResultCompare class: its docstring, __init__, the eval_handlers property, _loop_eval_handlers, detailed_summary, confusion_matrix, _set_default_kwargs, _setup_default_lineplot, roc_curve_plot, pr_plot, and all(). Indentation moves from two spaces to four; long signatures, calls, and assert/raise messages are rewrapped across multiple lines with trailing commas; and float literals are normalized (for example .92 becomes 0.92 and 4. becomes 4.0). The statements themselves are unchanged. ...]