clarifai 11.3.0rc2__py3-none-any.whl → 11.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (300)
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/__main__.py +1 -1
  3. clarifai/cli/base.py +144 -136
  4. clarifai/cli/compute_cluster.py +45 -31
  5. clarifai/cli/deployment.py +93 -76
  6. clarifai/cli/model.py +578 -180
  7. clarifai/cli/nodepool.py +100 -82
  8. clarifai/client/__init__.py +12 -2
  9. clarifai/client/app.py +973 -911
  10. clarifai/client/auth/helper.py +345 -342
  11. clarifai/client/auth/register.py +7 -7
  12. clarifai/client/auth/stub.py +107 -106
  13. clarifai/client/base.py +185 -178
  14. clarifai/client/compute_cluster.py +214 -180
  15. clarifai/client/dataset.py +793 -698
  16. clarifai/client/deployment.py +55 -50
  17. clarifai/client/input.py +1223 -1088
  18. clarifai/client/lister.py +47 -45
  19. clarifai/client/model.py +1939 -1717
  20. clarifai/client/model_client.py +525 -502
  21. clarifai/client/module.py +82 -73
  22. clarifai/client/nodepool.py +358 -213
  23. clarifai/client/runner.py +58 -0
  24. clarifai/client/search.py +342 -309
  25. clarifai/client/user.py +419 -414
  26. clarifai/client/workflow.py +294 -274
  27. clarifai/constants/dataset.py +11 -17
  28. clarifai/constants/model.py +8 -2
  29. clarifai/datasets/export/inputs_annotations.py +233 -217
  30. clarifai/datasets/upload/base.py +63 -51
  31. clarifai/datasets/upload/features.py +43 -38
  32. clarifai/datasets/upload/image.py +237 -207
  33. clarifai/datasets/upload/loaders/coco_captions.py +34 -32
  34. clarifai/datasets/upload/loaders/coco_detection.py +72 -65
  35. clarifai/datasets/upload/loaders/imagenet_classification.py +57 -53
  36. clarifai/datasets/upload/loaders/xview_detection.py +274 -132
  37. clarifai/datasets/upload/multimodal.py +55 -46
  38. clarifai/datasets/upload/text.py +55 -47
  39. clarifai/datasets/upload/utils.py +250 -234
  40. clarifai/errors.py +51 -50
  41. clarifai/models/api.py +260 -238
  42. clarifai/modules/css.py +50 -50
  43. clarifai/modules/pages.py +33 -33
  44. clarifai/rag/rag.py +312 -288
  45. clarifai/rag/utils.py +91 -84
  46. clarifai/runners/models/model_builder.py +906 -802
  47. clarifai/runners/models/model_class.py +370 -331
  48. clarifai/runners/models/model_run_locally.py +459 -419
  49. clarifai/runners/models/model_runner.py +170 -162
  50. clarifai/runners/models/model_servicer.py +78 -70
  51. clarifai/runners/server.py +111 -101
  52. clarifai/runners/utils/code_script.py +225 -187
  53. clarifai/runners/utils/const.py +4 -1
  54. clarifai/runners/utils/data_types/__init__.py +12 -0
  55. clarifai/runners/utils/data_types/data_types.py +598 -0
  56. clarifai/runners/utils/data_utils.py +387 -440
  57. clarifai/runners/utils/loader.py +247 -227
  58. clarifai/runners/utils/method_signatures.py +411 -386
  59. clarifai/runners/utils/openai_convertor.py +108 -109
  60. clarifai/runners/utils/serializers.py +175 -179
  61. clarifai/runners/utils/url_fetcher.py +35 -35
  62. clarifai/schema/search.py +56 -63
  63. clarifai/urls/helper.py +125 -102
  64. clarifai/utils/cli.py +129 -123
  65. clarifai/utils/config.py +127 -87
  66. clarifai/utils/constants.py +49 -0
  67. clarifai/utils/evaluation/helpers.py +503 -466
  68. clarifai/utils/evaluation/main.py +431 -393
  69. clarifai/utils/evaluation/testset_annotation_parser.py +154 -144
  70. clarifai/utils/logging.py +324 -306
  71. clarifai/utils/misc.py +60 -56
  72. clarifai/utils/model_train.py +165 -146
  73. clarifai/utils/protobuf.py +126 -103
  74. clarifai/versions.py +3 -1
  75. clarifai/workflows/export.py +48 -50
  76. clarifai/workflows/utils.py +39 -36
  77. clarifai/workflows/validate.py +55 -43
  78. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/METADATA +16 -6
  79. clarifai-11.4.0.dist-info/RECORD +109 -0
  80. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/WHEEL +1 -1
  81. clarifai/__pycache__/__init__.cpython-310.pyc +0 -0
  82. clarifai/__pycache__/__init__.cpython-311.pyc +0 -0
  83. clarifai/__pycache__/__init__.cpython-39.pyc +0 -0
  84. clarifai/__pycache__/errors.cpython-310.pyc +0 -0
  85. clarifai/__pycache__/errors.cpython-311.pyc +0 -0
  86. clarifai/__pycache__/versions.cpython-310.pyc +0 -0
  87. clarifai/__pycache__/versions.cpython-311.pyc +0 -0
  88. clarifai/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  89. clarifai/cli/__pycache__/__init__.cpython-311.pyc +0 -0
  90. clarifai/cli/__pycache__/base.cpython-310.pyc +0 -0
  91. clarifai/cli/__pycache__/base.cpython-311.pyc +0 -0
  92. clarifai/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  93. clarifai/cli/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  94. clarifai/cli/__pycache__/compute_cluster.cpython-311.pyc +0 -0
  95. clarifai/cli/__pycache__/deployment.cpython-310.pyc +0 -0
  96. clarifai/cli/__pycache__/deployment.cpython-311.pyc +0 -0
  97. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  98. clarifai/cli/__pycache__/model.cpython-311.pyc +0 -0
  99. clarifai/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  100. clarifai/cli/__pycache__/nodepool.cpython-310.pyc +0 -0
  101. clarifai/cli/__pycache__/nodepool.cpython-311.pyc +0 -0
  102. clarifai/client/__pycache__/__init__.cpython-310.pyc +0 -0
  103. clarifai/client/__pycache__/__init__.cpython-311.pyc +0 -0
  104. clarifai/client/__pycache__/__init__.cpython-39.pyc +0 -0
  105. clarifai/client/__pycache__/app.cpython-310.pyc +0 -0
  106. clarifai/client/__pycache__/app.cpython-311.pyc +0 -0
  107. clarifai/client/__pycache__/app.cpython-39.pyc +0 -0
  108. clarifai/client/__pycache__/base.cpython-310.pyc +0 -0
  109. clarifai/client/__pycache__/base.cpython-311.pyc +0 -0
  110. clarifai/client/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  111. clarifai/client/__pycache__/compute_cluster.cpython-311.pyc +0 -0
  112. clarifai/client/__pycache__/dataset.cpython-310.pyc +0 -0
  113. clarifai/client/__pycache__/dataset.cpython-311.pyc +0 -0
  114. clarifai/client/__pycache__/deployment.cpython-310.pyc +0 -0
  115. clarifai/client/__pycache__/deployment.cpython-311.pyc +0 -0
  116. clarifai/client/__pycache__/input.cpython-310.pyc +0 -0
  117. clarifai/client/__pycache__/input.cpython-311.pyc +0 -0
  118. clarifai/client/__pycache__/lister.cpython-310.pyc +0 -0
  119. clarifai/client/__pycache__/lister.cpython-311.pyc +0 -0
  120. clarifai/client/__pycache__/model.cpython-310.pyc +0 -0
  121. clarifai/client/__pycache__/model.cpython-311.pyc +0 -0
  122. clarifai/client/__pycache__/module.cpython-310.pyc +0 -0
  123. clarifai/client/__pycache__/module.cpython-311.pyc +0 -0
  124. clarifai/client/__pycache__/nodepool.cpython-310.pyc +0 -0
  125. clarifai/client/__pycache__/nodepool.cpython-311.pyc +0 -0
  126. clarifai/client/__pycache__/search.cpython-310.pyc +0 -0
  127. clarifai/client/__pycache__/search.cpython-311.pyc +0 -0
  128. clarifai/client/__pycache__/user.cpython-310.pyc +0 -0
  129. clarifai/client/__pycache__/user.cpython-311.pyc +0 -0
  130. clarifai/client/__pycache__/workflow.cpython-310.pyc +0 -0
  131. clarifai/client/__pycache__/workflow.cpython-311.pyc +0 -0
  132. clarifai/client/auth/__pycache__/__init__.cpython-310.pyc +0 -0
  133. clarifai/client/auth/__pycache__/__init__.cpython-311.pyc +0 -0
  134. clarifai/client/auth/__pycache__/helper.cpython-310.pyc +0 -0
  135. clarifai/client/auth/__pycache__/helper.cpython-311.pyc +0 -0
  136. clarifai/client/auth/__pycache__/register.cpython-310.pyc +0 -0
  137. clarifai/client/auth/__pycache__/register.cpython-311.pyc +0 -0
  138. clarifai/client/auth/__pycache__/stub.cpython-310.pyc +0 -0
  139. clarifai/client/auth/__pycache__/stub.cpython-311.pyc +0 -0
  140. clarifai/client/cli/__init__.py +0 -0
  141. clarifai/client/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  142. clarifai/client/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  143. clarifai/client/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  144. clarifai/client/cli/base_cli.py +0 -88
  145. clarifai/client/cli/model_cli.py +0 -29
  146. clarifai/constants/__pycache__/base.cpython-310.pyc +0 -0
  147. clarifai/constants/__pycache__/base.cpython-311.pyc +0 -0
  148. clarifai/constants/__pycache__/dataset.cpython-310.pyc +0 -0
  149. clarifai/constants/__pycache__/dataset.cpython-311.pyc +0 -0
  150. clarifai/constants/__pycache__/input.cpython-310.pyc +0 -0
  151. clarifai/constants/__pycache__/input.cpython-311.pyc +0 -0
  152. clarifai/constants/__pycache__/model.cpython-310.pyc +0 -0
  153. clarifai/constants/__pycache__/model.cpython-311.pyc +0 -0
  154. clarifai/constants/__pycache__/rag.cpython-310.pyc +0 -0
  155. clarifai/constants/__pycache__/rag.cpython-311.pyc +0 -0
  156. clarifai/constants/__pycache__/search.cpython-310.pyc +0 -0
  157. clarifai/constants/__pycache__/search.cpython-311.pyc +0 -0
  158. clarifai/constants/__pycache__/workflow.cpython-310.pyc +0 -0
  159. clarifai/constants/__pycache__/workflow.cpython-311.pyc +0 -0
  160. clarifai/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  161. clarifai/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
  162. clarifai/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  163. clarifai/datasets/export/__pycache__/__init__.cpython-310.pyc +0 -0
  164. clarifai/datasets/export/__pycache__/__init__.cpython-311.pyc +0 -0
  165. clarifai/datasets/export/__pycache__/__init__.cpython-39.pyc +0 -0
  166. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-310.pyc +0 -0
  167. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-311.pyc +0 -0
  168. clarifai/datasets/upload/__pycache__/__init__.cpython-310.pyc +0 -0
  169. clarifai/datasets/upload/__pycache__/__init__.cpython-311.pyc +0 -0
  170. clarifai/datasets/upload/__pycache__/__init__.cpython-39.pyc +0 -0
  171. clarifai/datasets/upload/__pycache__/base.cpython-310.pyc +0 -0
  172. clarifai/datasets/upload/__pycache__/base.cpython-311.pyc +0 -0
  173. clarifai/datasets/upload/__pycache__/features.cpython-310.pyc +0 -0
  174. clarifai/datasets/upload/__pycache__/features.cpython-311.pyc +0 -0
  175. clarifai/datasets/upload/__pycache__/image.cpython-310.pyc +0 -0
  176. clarifai/datasets/upload/__pycache__/image.cpython-311.pyc +0 -0
  177. clarifai/datasets/upload/__pycache__/multimodal.cpython-310.pyc +0 -0
  178. clarifai/datasets/upload/__pycache__/multimodal.cpython-311.pyc +0 -0
  179. clarifai/datasets/upload/__pycache__/text.cpython-310.pyc +0 -0
  180. clarifai/datasets/upload/__pycache__/text.cpython-311.pyc +0 -0
  181. clarifai/datasets/upload/__pycache__/utils.cpython-310.pyc +0 -0
  182. clarifai/datasets/upload/__pycache__/utils.cpython-311.pyc +0 -0
  183. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-311.pyc +0 -0
  184. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-39.pyc +0 -0
  185. clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-311.pyc +0 -0
  186. clarifai/datasets/upload/loaders/__pycache__/imagenet_classification.cpython-311.pyc +0 -0
  187. clarifai/models/__pycache__/__init__.cpython-39.pyc +0 -0
  188. clarifai/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  189. clarifai/rag/__pycache__/__init__.cpython-310.pyc +0 -0
  190. clarifai/rag/__pycache__/__init__.cpython-311.pyc +0 -0
  191. clarifai/rag/__pycache__/__init__.cpython-39.pyc +0 -0
  192. clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
  193. clarifai/rag/__pycache__/rag.cpython-311.pyc +0 -0
  194. clarifai/rag/__pycache__/rag.cpython-39.pyc +0 -0
  195. clarifai/rag/__pycache__/utils.cpython-310.pyc +0 -0
  196. clarifai/rag/__pycache__/utils.cpython-311.pyc +0 -0
  197. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  198. clarifai/runners/__pycache__/__init__.cpython-311.pyc +0 -0
  199. clarifai/runners/__pycache__/__init__.cpython-39.pyc +0 -0
  200. clarifai/runners/dockerfile_template/Dockerfile.cpu.template +0 -31
  201. clarifai/runners/dockerfile_template/Dockerfile.cuda.template +0 -42
  202. clarifai/runners/dockerfile_template/Dockerfile.nim +0 -71
  203. clarifai/runners/models/__pycache__/__init__.cpython-310.pyc +0 -0
  204. clarifai/runners/models/__pycache__/__init__.cpython-311.pyc +0 -0
  205. clarifai/runners/models/__pycache__/__init__.cpython-39.pyc +0 -0
  206. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  207. clarifai/runners/models/__pycache__/base_typed_model.cpython-311.pyc +0 -0
  208. clarifai/runners/models/__pycache__/base_typed_model.cpython-39.pyc +0 -0
  209. clarifai/runners/models/__pycache__/model_builder.cpython-311.pyc +0 -0
  210. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  211. clarifai/runners/models/__pycache__/model_class.cpython-311.pyc +0 -0
  212. clarifai/runners/models/__pycache__/model_run_locally.cpython-310-pytest-7.1.2.pyc +0 -0
  213. clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
  214. clarifai/runners/models/__pycache__/model_run_locally.cpython-311.pyc +0 -0
  215. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  216. clarifai/runners/models/__pycache__/model_runner.cpython-311.pyc +0 -0
  217. clarifai/runners/models/__pycache__/model_upload.cpython-310.pyc +0 -0
  218. clarifai/runners/models/base_typed_model.py +0 -238
  219. clarifai/runners/models/model_class_refract.py +0 -80
  220. clarifai/runners/models/model_upload.py +0 -607
  221. clarifai/runners/models/temp.py +0 -25
  222. clarifai/runners/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  223. clarifai/runners/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  224. clarifai/runners/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  225. clarifai/runners/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  226. clarifai/runners/utils/__pycache__/buffered_stream.cpython-310.pyc +0 -0
  227. clarifai/runners/utils/__pycache__/buffered_stream.cpython-38.pyc +0 -0
  228. clarifai/runners/utils/__pycache__/buffered_stream.cpython-39.pyc +0 -0
  229. clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
  230. clarifai/runners/utils/__pycache__/const.cpython-311.pyc +0 -0
  231. clarifai/runners/utils/__pycache__/constants.cpython-310.pyc +0 -0
  232. clarifai/runners/utils/__pycache__/constants.cpython-38.pyc +0 -0
  233. clarifai/runners/utils/__pycache__/constants.cpython-39.pyc +0 -0
  234. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  235. clarifai/runners/utils/__pycache__/data_handler.cpython-311.pyc +0 -0
  236. clarifai/runners/utils/__pycache__/data_handler.cpython-38.pyc +0 -0
  237. clarifai/runners/utils/__pycache__/data_handler.cpython-39.pyc +0 -0
  238. clarifai/runners/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
  239. clarifai/runners/utils/__pycache__/data_utils.cpython-311.pyc +0 -0
  240. clarifai/runners/utils/__pycache__/data_utils.cpython-38.pyc +0 -0
  241. clarifai/runners/utils/__pycache__/data_utils.cpython-39.pyc +0 -0
  242. clarifai/runners/utils/__pycache__/grpc_server.cpython-310.pyc +0 -0
  243. clarifai/runners/utils/__pycache__/grpc_server.cpython-38.pyc +0 -0
  244. clarifai/runners/utils/__pycache__/grpc_server.cpython-39.pyc +0 -0
  245. clarifai/runners/utils/__pycache__/health.cpython-310.pyc +0 -0
  246. clarifai/runners/utils/__pycache__/health.cpython-38.pyc +0 -0
  247. clarifai/runners/utils/__pycache__/health.cpython-39.pyc +0 -0
  248. clarifai/runners/utils/__pycache__/loader.cpython-310.pyc +0 -0
  249. clarifai/runners/utils/__pycache__/loader.cpython-311.pyc +0 -0
  250. clarifai/runners/utils/__pycache__/logging.cpython-310.pyc +0 -0
  251. clarifai/runners/utils/__pycache__/logging.cpython-38.pyc +0 -0
  252. clarifai/runners/utils/__pycache__/logging.cpython-39.pyc +0 -0
  253. clarifai/runners/utils/__pycache__/stream_source.cpython-310.pyc +0 -0
  254. clarifai/runners/utils/__pycache__/stream_source.cpython-39.pyc +0 -0
  255. clarifai/runners/utils/__pycache__/url_fetcher.cpython-310.pyc +0 -0
  256. clarifai/runners/utils/__pycache__/url_fetcher.cpython-311.pyc +0 -0
  257. clarifai/runners/utils/__pycache__/url_fetcher.cpython-38.pyc +0 -0
  258. clarifai/runners/utils/__pycache__/url_fetcher.cpython-39.pyc +0 -0
  259. clarifai/runners/utils/data_handler.py +0 -231
  260. clarifai/runners/utils/data_handler_refract.py +0 -213
  261. clarifai/runners/utils/data_types.py +0 -469
  262. clarifai/runners/utils/logger.py +0 -0
  263. clarifai/runners/utils/openai_format.py +0 -87
  264. clarifai/schema/__pycache__/search.cpython-310.pyc +0 -0
  265. clarifai/schema/__pycache__/search.cpython-311.pyc +0 -0
  266. clarifai/urls/__pycache__/helper.cpython-310.pyc +0 -0
  267. clarifai/urls/__pycache__/helper.cpython-311.pyc +0 -0
  268. clarifai/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  269. clarifai/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  270. clarifai/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  271. clarifai/utils/__pycache__/cli.cpython-310.pyc +0 -0
  272. clarifai/utils/__pycache__/cli.cpython-311.pyc +0 -0
  273. clarifai/utils/__pycache__/config.cpython-311.pyc +0 -0
  274. clarifai/utils/__pycache__/constants.cpython-310.pyc +0 -0
  275. clarifai/utils/__pycache__/constants.cpython-311.pyc +0 -0
  276. clarifai/utils/__pycache__/logging.cpython-310.pyc +0 -0
  277. clarifai/utils/__pycache__/logging.cpython-311.pyc +0 -0
  278. clarifai/utils/__pycache__/misc.cpython-310.pyc +0 -0
  279. clarifai/utils/__pycache__/misc.cpython-311.pyc +0 -0
  280. clarifai/utils/__pycache__/model_train.cpython-310.pyc +0 -0
  281. clarifai/utils/__pycache__/model_train.cpython-311.pyc +0 -0
  282. clarifai/utils/__pycache__/protobuf.cpython-311.pyc +0 -0
  283. clarifai/utils/evaluation/__pycache__/__init__.cpython-311.pyc +0 -0
  284. clarifai/utils/evaluation/__pycache__/__init__.cpython-39.pyc +0 -0
  285. clarifai/utils/evaluation/__pycache__/helpers.cpython-311.pyc +0 -0
  286. clarifai/utils/evaluation/__pycache__/main.cpython-311.pyc +0 -0
  287. clarifai/utils/evaluation/__pycache__/main.cpython-39.pyc +0 -0
  288. clarifai/workflows/__pycache__/__init__.cpython-310.pyc +0 -0
  289. clarifai/workflows/__pycache__/__init__.cpython-311.pyc +0 -0
  290. clarifai/workflows/__pycache__/__init__.cpython-39.pyc +0 -0
  291. clarifai/workflows/__pycache__/export.cpython-310.pyc +0 -0
  292. clarifai/workflows/__pycache__/export.cpython-311.pyc +0 -0
  293. clarifai/workflows/__pycache__/utils.cpython-310.pyc +0 -0
  294. clarifai/workflows/__pycache__/utils.cpython-311.pyc +0 -0
  295. clarifai/workflows/__pycache__/validate.cpython-310.pyc +0 -0
  296. clarifai/workflows/__pycache__/validate.cpython-311.pyc +0 -0
  297. clarifai-11.3.0rc2.dist-info/RECORD +0 -322
  298. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/entry_points.txt +0 -0
  299. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info/licenses}/LICENSE +0 -0
  300. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/top_level.txt +0 -0
clarifai/client/model.py CHANGED
@@ -22,1751 +22,1973 @@ from clarifai.client.input import Inputs
 from clarifai.client.lister import Lister
 from clarifai.client.model_client import ModelClient
 from clarifai.client.nodepool import Nodepool
-from clarifai.constants.model import (CHUNK_SIZE, MAX_CHUNK_SIZE, MAX_RANGE_SIZE, MIN_CHUNK_SIZE,
-                                      MIN_RANGE_SIZE, MODEL_EXPORT_TIMEOUT, RANGE_SIZE,
-                                      TRAINABLE_MODEL_TYPES)
+from clarifai.constants.model import (
+    CHUNK_SIZE,
+    MAX_CHUNK_SIZE,
+    MAX_RANGE_SIZE,
+    MIN_CHUNK_SIZE,
+    MIN_RANGE_SIZE,
+    MODEL_EXPORT_TIMEOUT,
+    RANGE_SIZE,
+    TRAINABLE_MODEL_TYPES,
+)
 from clarifai.errors import UserError
 from clarifai.urls.helper import ClarifaiUrlHelper
 from clarifai.utils.logging import logger
 from clarifai.utils.misc import BackoffIterator
-from clarifai.utils.model_train import (find_and_replace_key, params_parser,
-                                        response_to_model_params, response_to_param_info,
-                                        response_to_templates)
+from clarifai.utils.model_train import (
+    find_and_replace_key,
+    params_parser,
+    response_to_model_params,
+    response_to_param_info,
+    response_to_templates,
+)
 from clarifai.utils.protobuf import dict_to_protobuf
+
 MAX_SIZE_PER_STREAM = int(89_128_960)  # 85GiB
 MIN_CHUNK_FOR_UPLOAD_FILE = int(5_242_880)  # 5MiB
 MAX_CHUNK_FOR_UPLOAD_FILE = int(5_242_880_000)  # 5GiB


 class Model(Lister, BaseClient):
-  """Model is a class that provides access to Clarifai API endpoints related to Model information."""
-
-  def __init__(self,
-               url: str = None,
-               model_id: str = None,
-               model_version: Dict = {'id': ""},
-               base_url: str = "https://api.clarifai.com",
-               pat: str = None,
-               token: str = None,
-               root_certificates_path: str = None,
-               compute_cluster_id: str = None,
-               nodepool_id: str = None,
-               deployment_id: str = None,
-               **kwargs):
-    """Initializes a Model object.
-
-    Args:
-        url (str): The URL to initialize the model object.
-        model_id (str): The Model ID to interact with.
-        model_version (dict): The Model Version to interact with.
-        base_url (str): Base API url. Default "https://api.clarifai.com"
-        pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
-        token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
-        root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
-        **kwargs: Additional keyword arguments to be passed to the Model.
-    """
-    if url and model_id:
-      raise UserError("You can only specify one of url or model_id.")
-    if not url and not model_id:
-      raise UserError("You must specify one of url or model_id.")
-    if url:
-      user_id, app_id, _, model_id, model_version_id = ClarifaiUrlHelper.split_clarifai_url(url)
-      model_version = {'id': model_version_id}
-      kwargs = {'user_id': user_id, 'app_id': app_id}
-
-    self.kwargs = {**kwargs, 'id': model_id, 'model_version': model_version, }
-    self.model_info = resources_pb2.Model()
-    dict_to_protobuf(self.model_info, self.kwargs)
-
-    self.logger = logger
-    self.training_params = {}
-    self.input_types = None
-    self._client = None
-    self._added_methods = False
-    self._set_runner_selector(
-        compute_cluster_id=compute_cluster_id,
-        nodepool_id=nodepool_id,
-        deployment_id=deployment_id,
-        user_id=self.user_id,  # FIXME the deployment's user_id can be different than the model's.
-    )
-    BaseClient.__init__(
+    """Model is a class that provides access to Clarifai API endpoints related to Model information."""
+
+    def __init__(
         self,
-        user_id=self.user_id,
-        app_id=self.app_id,
-        base=base_url,
-        pat=pat,
-        token=token,
-        root_certificates_path=root_certificates_path)
-    Lister.__init__(self)
-
-  def list_training_templates(self) -> List[str]:
-    """Lists all the training templates for the model type.
-
-    Returns:
-        templates (List): List of training templates for the model type.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> print(model.list_training_templates())
-    """
-    if not self.model_info.model_type_id:
-      self.load_info()
-    if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
-      raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
-    request = service_pb2.ListModelTypesRequest(user_app_id=self.user_app_id,)
-    response = self._grpc_request(self.STUB.ListModelTypes, request)
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(response.status)
-    templates = response_to_templates(
-        response=response, model_type_id=self.model_info.model_type_id)
-
-    return templates
-
-  def get_params(self, template: str = None, save_to: str = 'params.yaml') -> Dict[str, Any]:
-    """Returns the model params for the model type and yaml file.
-
-    Args:
-        template (str): The template to use for the model type.
-        yaml_file (str): The yaml file to save the model params.
-
-    Returns:
-        params (Dict): Dictionary of model params for the model type.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
-    """
-    if not self.model_info.model_type_id:
-      self.load_info()
-    if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
-      raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
-    if template is None and self.model_info.model_type_id not in [
-        "clusterer", "embedding-classifier"
-    ]:
-      raise UserError(
-          f"Template should be provided for {self.model_info.model_type_id} model type")
-    if template is not None and self.model_info.model_type_id in [
-        "clusterer", "embedding-classifier"
-    ]:
-      raise UserError(
-          f"Template should not be provided for {self.model_info.model_type_id} model type")
-
-    request = service_pb2.ListModelTypesRequest(user_app_id=self.user_app_id,)
-    response = self._grpc_request(self.STUB.ListModelTypes, request)
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(response.status)
-    params = response_to_model_params(
-        response=response, model_type_id=self.model_info.model_type_id, template=template)
-    # yaml file
-    assert save_to.endswith('.yaml'), "File extension should be .yaml"
-    with open(save_to, 'w') as f:
-      yaml.dump(params, f, default_flow_style=False, sort_keys=False)
-    # updating the global model params
-    self.training_params.update(params)
-
-    return params
-
-  def update_params(self, **kwargs) -> None:
-    """Updates the model params for the model.
-
-    Args:
-        **kwargs: model params to update.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
-        >>> model.update_params(batch_size = 8, dataset_version = 'dataset_version_id')
-    """
-    if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
-      raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
-    if len(self.training_params) == 0:
-      raise UserError(
-          f"Run 'model.get_params' to get the params for the {self.model_info.model_type_id} model type"
-      )
-    # getting all the keys in nested dictionary
-    all_keys = [key for key in self.training_params.keys()] + [
-        key for key in self.training_params.values() if isinstance(key, dict) for key in key
-    ]
-    # checking if the given params are valid
-    if not set(kwargs.keys()).issubset(all_keys):
-      raise UserError("Invalid params")
-    # updating the global model params
-    for key, value in kwargs.items():
-      find_and_replace_key(self.training_params, key, value)
-
-  def get_param_info(self, param: str) -> Dict[str, Any]:
-    """Returns the param info for the param.
-
-    Args:
-        param (str): The param to get the info for.
-
-    Returns:
-        param_info (Dict): Dictionary of model param info for the param.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
-        >>> model.get_param_info('param')
-    """
-    if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
-      raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
-    if len(self.training_params) == 0:
-      raise UserError(
-          f"Run 'model.get_params' to get the params for the {self.model_info.model_type_id} model type"
-      )
-
-    all_keys = [key for key in self.training_params.keys()] + [
-        key for key in self.training_params.values() if isinstance(key, dict) for key in key
-    ]
-    if param not in all_keys:
-      raise UserError(f"Invalid param: '{param}' for model type '{self.model_info.model_type_id}'")
-    template = self.training_params['train_params']['template'] if 'template' in all_keys else None
-
-    request = service_pb2.ListModelTypesRequest(user_app_id=self.user_app_id,)
-    response = self._grpc_request(self.STUB.ListModelTypes, request)
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(response.status)
-    param_info = response_to_param_info(
-        response=response,
-        model_type_id=self.model_info.model_type_id,
-        param=param,
-        template=template)
-
-    return param_info
-
-  def train(self, yaml_file: str = None) -> str:
-    """Trains the model based on the given yaml file or model params.
-
-    Args:
-        yaml_file (str): The yaml file for the model params.
-
-    Returns:
-        model_version_id (str): The model version ID for the model.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> model_params = model.get_params(template='template', yaml_file='model_params.yaml')
-        >>> model.train('model_params.yaml')
-    """
-    if not self.model_info.model_type_id:
-      self.load_info()
-    if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
-      raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
-    if not yaml_file and len(self.training_params) == 0:
-      raise UserError("Provide yaml file or run 'model.get_params()'")
-
-    if yaml_file:
-      with open(yaml_file, 'r') as file:
-        params_dict = yaml.safe_load(file)
-    else:
-      params_dict = self.training_params
-    # getting all the concepts for the model type
-    if self.model_info.model_type_id not in ["clusterer", "text-to-text"]:
-      concepts = self._list_concepts()
-    train_dict = params_parser(params_dict, concepts)
-    request = service_pb2.PostModelVersionsRequest(
-        user_app_id=self.user_app_id,
-        model_id=self.id,
-        model_versions=[resources_pb2.ModelVersion(**train_dict)])
-    response = self._grpc_request(self.STUB.PostModelVersions, request)
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(response.status)
-    self.logger.info("\nModel Training Started\n%s", response.status)
-
-    return response.model.model_version.id
-
-  def training_status(self, version_id: str = None, training_logs: bool = False) -> Dict[str, str]:
-    """Get the training status for the model version. Also stores training logs
-
-    Args:
-        version_id (str): The version ID to get the training status for.
-        training_logs (bool): Whether to save the training logs in a file.
-
-    Returns:
-        training_status (Dict): Dictionary of training status for the model version.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> model.training_status(version_id='version_id',training_logs=True)
-    """
-    if not version_id and not self.model_info.model_version.id:
-      raise UserError(
-          "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
-      )
-
-    self.load_info()
-    if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
-      raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
-
-    if training_logs:
-      try:
-        if self.model_info.model_version.train_log:
-          log_response = requests.get(self.model_info.model_version.train_log)
-          log_response.raise_for_status()  # Check for any HTTP errors
-          with open(version_id + '.log', 'wb') as file:
-            for chunk in log_response.iter_content(chunk_size=4096):  # 4KB
-              file.write(chunk)
-          self.logger.info(f"\nTraining logs are saving in '{version_id+'.log'}' file")
-
-      except requests.exceptions.RequestException as e:
-        raise Exception(f"An error occurred while getting training logs: {e}")
-
-    return self.model_info.model_version.status
-
-  def delete_version(self, version_id: str) -> None:
-    """Deletes a model version for the Model.
-
-    Args:
-        version_id (str): The version ID to delete.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> model.delete_version(version_id='version_id')
-    """
-    request = service_pb2.DeleteModelVersionRequest(
-        user_app_id=self.user_app_id, model_id=self.id, version_id=version_id)
-
-    response = self._grpc_request(self.STUB.DeleteModelVersion, request)
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(response.status)
-    self.logger.info("\nModel Version Deleted\n%s", response.status)
-
-  def create_version(self, **kwargs) -> 'Model':
-    """Creates a model version for the Model.
-
-    Args:
-        **kwargs: Additional keyword arguments to be passed to Model Version.
-          - description (str): The description of the model version.
-          - concepts (list[Concept]): The concepts to associate with the model version.
-          - output_info (resources_pb2.OutputInfo(): The output info to associate with the model version.
-
-    Returns:
-        Model: A Model object for the specified model ID.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model("url")
-                    or
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> model_version = model.create_version(description='model_version_description')
-    """
-    if self.model_info.model_type_id in TRAINABLE_MODEL_TYPES:
-      raise UserError(
-          f"{self.model_info.model_type_id} is a trainable model type. Use 'model.train()' to train the model"
-      )
-
-    request = service_pb2.PostModelVersionsRequest(
-        user_app_id=self.user_app_id,
-        model_id=self.id,
-        model_versions=[resources_pb2.ModelVersion(**kwargs)])
-
-    response = self._grpc_request(self.STUB.PostModelVersions, request)
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(response.status)
-    self.logger.info("\nModel Version created\n%s", response.status)
-
-    kwargs.update({'app_id': self.app_id, 'user_id': self.user_id})
-    dict_response = MessageToDict(response, preserving_proto_field_name=True)
-    kwargs = self.process_response_keys(dict_response['model'], 'model')
-
-    return Model(base_url=self.base, pat=self.pat, token=self.token, **kwargs)
-
-  def list_versions(self, page_no: int = None,
-                    per_page: int = None) -> Generator['Model', None, None]:
-    """Lists all the versions for the model.
-
-    Args:
-        page_no (int): The page number to list.
-        per_page (int): The number of items per page.
-
-    Yields:
-        Model: Model objects for the versions of the model.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
-                    or
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> all_model_versions = list(model.list_versions())
-
-    Note:
-        Defaults to 16 per page if page_no is specified and per_page is not specified.
-        If both page_no and per_page are None, then lists all the resources.
-    """
-    request_data = dict(
-        user_app_id=self.user_app_id,
-        model_id=self.id,
-    )
-    all_model_versions_info = self.list_pages_generator(
-        self.STUB.ListModelVersions,
-        service_pb2.ListModelVersionsRequest,
-        request_data,
-        per_page=per_page,
-        page_no=page_no)
-
-    for model_version_info in all_model_versions_info:
-      model_version_info['id'] = model_version_info['model_version_id']
-      del model_version_info['model_version_id']
-      try:
-        del model_version_info['train_info']['dataset']['version']['metrics']
-      except KeyError:
-        pass
-      yield Model.from_auth_helper(
-          auth=self.auth_helper,
-          model_id=self.id,
-          **dict(self.kwargs, model_version=model_version_info))
-
-  @property
-  def client(self):
-    if self._client is None:
-      request_template = service_pb2.PostModelOutputsRequest(
-          user_app_id=self.user_app_id,
-          model_id=self.id,
-          version_id=self.model_version.id,
-          model=self.model_info,
-          runner_selector=self._runner_selector,
-      )
-      self._client = ModelClient(self.STUB, request_template=request_template)
-    return self._client
-
-  def predict(self, *args, **kwargs):
-    """
-    Calls the model's predict() method with the given arguments.
-
-    If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
-    protos directly for compatibility with previous versions of the SDK.
-    """
-
-    inputs = None
-    if 'inputs' in kwargs:
-      inputs = kwargs['inputs']
-    elif args:
-      inputs = args[0]
-    if inputs and isinstance(inputs, list) and isinstance(inputs[0], resources_pb2.Input):
-      assert len(args) <= 1, "Cannot pass in raw protos and additional arguments at the same time."
-      inference_params = kwargs.get('inference_params', {})
-      output_config = kwargs.get('output_config', {})
-      return self.client._predict_by_proto(
-          inputs=inputs, inference_params=inference_params, output_config=output_config)
-
-    return self.client.predict(*args, **kwargs)
-
-  def __getattr__(self, name):
-    try:
-      return getattr(self.model_info, name)
-    except AttributeError:
-      pass
-    if not self._added_methods:
-      # fetch and set all the model methods
-      self._added_methods = True
-      self.client.fetch()
-      for method_name in self.client._method_signatures.keys():
-        if not hasattr(self, method_name):
-          setattr(self, method_name, getattr(self.client, method_name))
-    if hasattr(self.client, name):
-      return getattr(self.client, name)
-    raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
-
-  def _check_predict_input_type(self, input_type: str) -> None:
-    """Checks if the input type is valid for the model.
-
-    Args:
-        input_type (str): The input type to check.
-    Returns:
-        None
-    """
-    if not input_type:
-      self.load_input_types()
-      if len(self.input_types) > 1:
-        raise UserError(
-            "Model has multiple input types. Please use model.predict() for this multi-modal model."
+        url: str = None,
+        model_id: str = None,
+        model_version: Dict = {'id': ""},
+        base_url: str = "https://api.clarifai.com",
+        pat: str = None,
+        token: str = None,
+        root_certificates_path: str = None,
+        compute_cluster_id: str = None,
+        nodepool_id: str = None,
+        deployment_id: str = None,
+        **kwargs,
+    ):
+        """Initializes a Model object.
+
+        Args:
+            url (str): The URL to initialize the model object.
+            model_id (str): The Model ID to interact with.
+            model_version (dict): The Model Version to interact with.
+            base_url (str): Base API url. Default "https://api.clarifai.com"
+            pat (str): A personal access token for authentication. Can be set as env var CLARIFAI_PAT
+            token (str): A session token for authentication. Accepts either a session token or a pat. Can be set as env var CLARIFAI_SESSION_TOKEN
+            root_certificates_path (str): Path to the SSL root certificates file, used to establish secure gRPC connections.
+            **kwargs: Additional keyword arguments to be passed to the Model.
+        """
+        if url and model_id:
+            raise UserError("You can only specify one of url or model_id.")
+        if not url and not model_id:
+            raise UserError("You must specify one of url or model_id.")
+        if url:
+            user_id, app_id, _, model_id, model_version_id = ClarifaiUrlHelper.split_clarifai_url(
+                url
+            )
+            model_version = {'id': model_version_id}
+            kwargs = {'user_id': user_id, 'app_id': app_id}
+
+        self.kwargs = {
+            **kwargs,
+            'id': model_id,
+            'model_version': model_version,
+        }
+        self.model_info = resources_pb2.Model()
+        dict_to_protobuf(self.model_info, self.kwargs)
+
+        self.logger = logger
+        self.training_params = {}
+        self.input_types = None
+        self._client = None
+        self._added_methods = False
+        self._set_runner_selector(
+            compute_cluster_id=compute_cluster_id,
+            nodepool_id=nodepool_id,
+            deployment_id=deployment_id,
+            user_id=self.user_id,  # FIXME the deployment's user_id can be different than the model's.
         )
-    else:
-      self.input_types = [input_type]
-    if self.input_types[0] not in {'image', 'text', 'video', 'audio'}:
-      raise UserError(
-          f"Got input type {input_type} but expected one of image, text, video, audio.")
+        BaseClient.__init__(
+            self,
+            user_id=self.user_id,
+            app_id=self.app_id,
+            base=base_url,
+            pat=pat,
+            token=token,
+            root_certificates_path=root_certificates_path,
+        )
+        Lister.__init__(self)

-  def load_input_types(self) -> None:
-    """Loads the input types for the model.
+    @classmethod
+    def from_current_context(cls, **kwargs) -> 'Model':
+        from clarifai.utils.config import Config

-    Returns:
-        None
+        current = Config.from_yaml().current

-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
-                    or
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> model.load_input_types()
-    """
-    if self.input_types:
-      return self.input_types
-    if self.model_info.model_type_id == "":
-      self.load_info()
-    request = service_pb2.GetModelTypeRequest(
-        user_app_id=self.user_app_id,
-        model_type_id=self.model_info.model_type_id,
-    )
-    response = self._grpc_request(self.STUB.GetModelType, request)
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(response.status)
-    self.input_types = response.model_type.input_fields
-
-  def _set_runner_selector(self,
-                           compute_cluster_id: str = None,
-                           nodepool_id: str = None,
-                           deployment_id: str = None,
-                           user_id: str = None):
-    runner_selector = None
-    if deployment_id and (compute_cluster_id or nodepool_id):
-      raise UserError(
-          "You can only specify one of deployment_id or compute_cluster_id and nodepool_id.")
-
-    if deployment_id:
-      if not user_id and not os.environ.get('CLARIFAI_USER_ID'):
-        raise UserError(
-            "User ID is required for model prediction with deployment ID, please provide user_id in the method call."
+        # set the current context to env vars.
+        current.set_to_env()
+
+        url = f"https://clarifai.com/{current.user_id}/{current.app_id}/models/{current.model_id}"
+
+        # construct the Model object.
+        kwargs = {}
+        try:
+            kwargs['deployment_id'] = current.deployment_id
+        except AttributeError:
+            try:
+                kwargs['compute_cluster_id'] = current.compute_cluster_id
+                kwargs['nodepool_id'] = current.nodepool_id
+            except AttributeError:
+                pass
+
+        return Model(url, base_url=current.api_base, pat=current.pat, **kwargs)
+
+    def list_training_templates(self) -> List[str]:
+        """Lists all the training templates for the model type.
+
+        Returns:
+            templates (List): List of training templates for the model type.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> print(model.list_training_templates())
+        """
+        if not self.model_info.model_type_id:
+            self.load_info()
+        if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
+            raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
+        request = service_pb2.ListModelTypesRequest(
+            user_app_id=self.user_app_id,
         )
-      if not user_id:
-        user_id = os.environ.get('CLARIFAI_USER_ID')
-      runner_selector = Deployment.get_runner_selector(
-          user_id=user_id, deployment_id=deployment_id)
-    elif compute_cluster_id and nodepool_id:
-      if not user_id and not os.environ.get('CLARIFAI_USER_ID'):
-        raise UserError(
-            "User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
+        response = self._grpc_request(self.STUB.ListModelTypes, request)
+        if response.status.code != status_code_pb2.SUCCESS:
+            raise Exception(response.status)
+        templates = response_to_templates(
+            response=response, model_type_id=self.model_info.model_type_id
         )
-      if not user_id:
-        user_id = os.environ.get('CLARIFAI_USER_ID')
-      runner_selector = Nodepool.get_runner_selector(
-          user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id)
-
-    # set the runner selector
-    self._runner_selector = runner_selector
-
-  def predict_by_filepath(self,
-                          filepath: str,
-                          input_type: str = None,
-                          inference_params: Dict = {},
-                          output_config: Dict = {}):
-    """Predicts the model based on the given filepath.
-
-    Args:
-        filepath (str): The filepath to predict.
-        input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
-        inference_params (dict): The inference params to override.
-        output_config (dict): The output config to override.
-          min_value (float): The minimum value of the prediction confidence to filter.
-          max_concepts (int): The maximum number of concepts to return.
-          select_concepts (list[Concept]): The concepts to select.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
-                    or
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> model_prediction = model.predict_by_filepath('/path/to/image.jpg')
-        >>> model_prediction = model.predict_by_filepath('/path/to/text.txt')
-    """
-    if not os.path.isfile(filepath):
-      raise UserError('Invalid filepath.')
-
-    with open(filepath, "rb") as f:
-      file_bytes = f.read()
-
-    return self.predict_by_bytes(file_bytes, input_type, inference_params, output_config)
-
-  def predict_by_bytes(self,
-                       input_bytes: bytes,
-                       input_type: str = None,
-                       inference_params: Dict = {},
-                       output_config: Dict = {}):
-    """Predicts the model based on the given bytes.
-
-    Args:
-        input_bytes (bytes): File Bytes to predict on.
-        input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
-        inference_params (dict): The inference params to override.
-        output_config (dict): The output config to override.
-          min_value (float): The minimum value of the prediction confidence to filter.
-          max_concepts (int): The maximum number of concepts to return.
-          select_concepts (list[Concept]): The concepts to select.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
-        >>> model_prediction = model.predict_by_bytes(b'Write a tweet on future of AI',
-                                                      inference_params=dict(temperature=str(0.7), max_tokens=30)))
-    """
-    self._check_predict_input_type(input_type)
-
-    if self.input_types[0] == "image":
-      input_proto = Inputs.get_input_from_bytes("", image_bytes=input_bytes)
-    elif self.input_types[0] == "text":
-      input_proto = Inputs.get_input_from_bytes("", text_bytes=input_bytes)
-    elif self.input_types[0] == "video":
-      input_proto = Inputs.get_input_from_bytes("", video_bytes=input_bytes)
-    elif self.input_types[0] == "audio":
-      input_proto = Inputs.get_input_from_bytes("", audio_bytes=input_bytes)
-
-    return self.predict(
-        inputs=[input_proto], inference_params=inference_params, output_config=output_config)
-
-  def predict_by_url(self,
-                     url: str,
-                     input_type: str = None,
-                     inference_params: Dict = {},
-                     output_config: Dict = {}):
-    """Predicts the model based on the given URL.
-
-    Args:
-        url (str): The URL to predict.
-        input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
-        inference_params (dict): The inference params to override.
-        output_config (dict): The output config to override.
-          min_value (float): The minimum value of the prediction confidence to filter.
-          max_concepts (int): The maximum number of concepts to return.
-          select_concepts (list[Concept]): The concepts to select.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
-                    or
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> model_prediction = model.predict_by_url('url')
-    """
-    self._check_predict_input_type(input_type)
-
-    if self.input_types[0] == "image":
-      input_proto = Inputs.get_input_from_url("", image_url=url)
-    elif self.input_types[0] == "text":
-      input_proto = Inputs.get_input_from_url("", text_url=url)
-    elif self.input_types[0] == "video":
-      input_proto = Inputs.get_input_from_url("", video_url=url)
-    elif self.input_types[0] == "audio":
-      input_proto = Inputs.get_input_from_url("", audio_url=url)
-
-    return self.predict(
-        inputs=[input_proto], inference_params=inference_params, output_config=output_config)
-
-  def generate(self, *args, **kwargs):
-    """
-    Calls the model's generate() method with the given arguments.
-
-    If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
-    protos directly for compatibility with previous versions of the SDK.
-    """
-
-    inputs = None
-    if 'inputs' in kwargs:
-      inputs = kwargs['inputs']
-    elif args:
-      inputs = args[0]
-    if inputs and isinstance(inputs, list) and isinstance(inputs[0], resources_pb2.Input):
-      assert len(args) <= 1, "Cannot pass in raw protos and additional arguments at the same time."
-      inference_params = kwargs.get('inference_params', {})
-      output_config = kwargs.get('output_config', {})
-      return self.client._generate_by_proto(
-          inputs=inputs, inference_params=inference_params, output_config=output_config)
-
-    return self.client.generate(*args, **kwargs)
-
-  def generate_by_filepath(self,
-                           filepath: str,
-                           input_type: str = None,
-                           inference_params: Dict = {},
-                           output_config: Dict = {}):
-    """Generate the stream output on model based on the given filepath.
-
-    Args:
-        filepath (str): The filepath to predict.
-        input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
-        inference_params (dict): The inference params to override.
-        output_config (dict): The output config to override.
-          min_value (float): The minimum value of the prediction confidence to filter.
-          max_concepts (int): The maximum number of concepts to return.
-          select_concepts (list[Concept]): The concepts to select.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
-                    or
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> stream_response = model.generate_by_filepath('/path/to/image.jpg', 'image', deployment_id='deployment_id')
-        >>> list_stream_response = [response for response in stream_response]
-    """
-    if not os.path.isfile(filepath):
-      raise UserError('Invalid filepath.')
-
-    with open(filepath, "rb") as f:
-      file_bytes = f.read()
-
-    return self.generate_by_bytes(
-        input_bytes=file_bytes,
-        input_type=input_type,
-        inference_params=inference_params,
-        output_config=output_config)
-
-  def generate_by_bytes(self,
-                        input_bytes: bytes,
-                        input_type: str = None,
-                        inference_params: Dict = {},
-                        output_config: Dict = {}):
-    """Generate the stream output on model based on the given bytes.
-
-    Args:
-        input_bytes (bytes): File Bytes to predict on.
-        input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
-        inference_params (dict): The inference params to override.
-        output_config (dict): The output config to override.
-          min_value (float): The minimum value of the prediction confidence to filter.
-          max_concepts (int): The maximum number of concepts to return.
-          select_concepts (list[Concept]): The concepts to select.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
-        >>> stream_response = model.generate_by_bytes(b'Write a tweet on future of AI',
-                                                      deployment_id='deployment_id',
-                                                      inference_params=dict(temperature=str(0.7), max_tokens=30)))
-        >>> list_stream_response = [response for response in stream_response]
-    """
-    self._check_predict_input_type(input_type)
-
-    if self.input_types[0] == "image":
-      input_proto = Inputs.get_input_from_bytes("", image_bytes=input_bytes)
-    elif self.input_types[0] == "text":
-      input_proto = Inputs.get_input_from_bytes("", text_bytes=input_bytes)
-    elif self.input_types[0] == "video":
-      input_proto = Inputs.get_input_from_bytes("", video_bytes=input_bytes)
-    elif self.input_types[0] == "audio":
-      input_proto = Inputs.get_input_from_bytes("", audio_bytes=input_bytes)
-
-    return self.generate(
-        inputs=[input_proto], inference_params=inference_params, output_config=output_config)
-
-  def generate_by_url(self,
-                      url: str,
-                      input_type: str = None,
-                      inference_params: Dict = {},
-                      output_config: Dict = {}):
-    """Generate the stream output on model based on the given URL.
-
-    Args:
-        url (str): The URL to predict.
-        input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
-        inference_params (dict): The inference params to override.
-        output_config (dict): The output config to override.
-          min_value (float): The minimum value of the prediction confidence to filter.
-          max_concepts (int): The maximum number of concepts to return.
-          select_concepts (list[Concept]): The concepts to select.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
-                    or
-        >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
-        >>> stream_response = model.generate_by_url('url', deployment_id='deployment_id')
-        >>> list_stream_response = [response for response in stream_response]
-    """
-    self._check_predict_input_type(input_type)
-
-    if self.input_types[0] == "image":
-      input_proto = Inputs.get_input_from_url("", image_url=url)
-    elif self.input_types[0] == "text":
-      input_proto = Inputs.get_input_from_url("", text_url=url)
-    elif self.input_types[0] == "video":
-      input_proto = Inputs.get_input_from_url("", video_url=url)
-    elif self.input_types[0] == "audio":
-      input_proto = Inputs.get_input_from_url("", audio_url=url)
-
-    return self.generate(
-        inputs=[input_proto], inference_params=inference_params, output_config=output_config)
-
-  def stream(self, *args, **kwargs):
-    """
-    Calls the model's stream() method with the given arguments.
-
-    If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
-    protos directly for compatibility with previous versions of the SDK.
-    """
-
-    use_proto_call = False
-    inputs = None
-    if 'inputs' in kwargs:
-      inputs = kwargs['inputs']
-    elif args:
-      inputs = args[0]
-    if inputs and isinstance(inputs, Iterable):
-      inputs_iter = inputs
-      try:
-        peek = next(inputs_iter)
-      except StopIteration:
-        pass
-      else:
-        use_proto_call = (peek and isinstance(peek, list) and
-                          isinstance(peek[0], resources_pb2.Input))
-        # put back the peeked value
-        if inputs_iter is inputs:
-          inputs = itertools.chain([peek], inputs_iter)
-          if 'inputs' in kwargs:
-            kwargs['inputs'] = inputs
-          else:
-            args = (inputs,) + args[1:]
-
-    if use_proto_call:
-      assert len(args) <= 1, "Cannot pass in raw protos and additional arguments at the same time."
-      inference_params = kwargs.get('inference_params', {})
-      output_config = kwargs.get('output_config', {})
-      return self.client._stream_by_proto(
-          inputs=inputs, inference_params=inference_params, output_config=output_config)
-
-    return self.client.stream(*args, **kwargs)
-
-  def stream_by_filepath(self,
-                         filepath: str,
-                         input_type: str = None,
-                         inference_params: Dict = {},
-                         output_config: Dict = {}):
-    """Stream the model output based on the given filepath.
-
-    Args:
-        filepath (str): The filepath to predict.
-        input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
-        inference_params (dict): The inference params to override.
-        output_config (dict): The output config to override.
-          min_value (float): The minimum value of the prediction confidence to filter.
-          max_concepts (int): The maximum number of concepts to return.
-          select_concepts (list[Concept]): The concepts to select.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model("url")
-        >>> stream_response = model.stream_by_filepath('/path/to/image.jpg', deployment_id='deployment_id')
-        >>> list_stream_response = [response for response in stream_response]
-    """
-    if not os.path.isfile(filepath):
-      raise UserError('Invalid filepath.')
-
-    with open(filepath, "rb") as f:
-      file_bytes = f.read()
-
-    return self.stream_by_bytes(
-        input_bytes_iterator=iter([file_bytes]),
-        input_type=input_type,
-        inference_params=inference_params,
-        output_config=output_config)
-
-  def stream_by_bytes(self,
-                      input_bytes_iterator: Iterator[bytes],
-                      input_type: str = None,
-                      inference_params: Dict = {},
-                      output_config: Dict = {}):
-    """Stream the model output based on the given bytes.
-
-    Args:
-        input_bytes_iterator (Iterator[bytes]): Iterator of file bytes to predict on.
-        input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
-        inference_params (dict): The inference params to override.
879
- output_config (dict): The output config to override.
880
- min_value (float): The minimum value of the prediction confidence to filter.
881
- max_concepts (int): The maximum number of concepts to return.
882
- select_concepts (list[Concept]): The concepts to select.
883
-
884
- Example:
885
- >>> from clarifai.client.model import Model
886
- >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
887
- >>> stream_response = model.stream_by_bytes(iter([b'Write a tweet on future of AI']),
888
- deployment_id='deployment_id',
889
- inference_params=dict(temperature=str(0.7), max_tokens=30)))
890
- >>> list_stream_response = [response for response in stream_response]
891
- """
892
- self._check_predict_input_type(input_type)
893
-
894
- def input_generator():
895
- for input_bytes in input_bytes_iterator:
171
+
+        return templates
+
+    def get_params(self, template: str = None, save_to: str = 'params.yaml') -> Dict[str, Any]:
+        """Returns the model params for the model type and saves them to a yaml file.
+
+        Args:
+            template (str): The template to use for the model type.
+            save_to (str): The yaml file to save the model params to.
+
+        Returns:
+            params (Dict): Dictionary of model params for the model type.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> model_params = model.get_params(template='template', save_to='model_params.yaml')
+        """
+        if not self.model_info.model_type_id:
+            self.load_info()
+        if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
+            raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
+        if template is None and self.model_info.model_type_id not in [
+            "clusterer",
+            "embedding-classifier",
+        ]:
+            raise UserError(
+                f"Template should be provided for {self.model_info.model_type_id} model type"
+            )
+        if template is not None and self.model_info.model_type_id in [
+            "clusterer",
+            "embedding-classifier",
+        ]:
+            raise UserError(
+                f"Template should not be provided for {self.model_info.model_type_id} model type"
+            )
+
+        request = service_pb2.ListModelTypesRequest(
+            user_app_id=self.user_app_id,
+        )
+        response = self._grpc_request(self.STUB.ListModelTypes, request)
+        if response.status.code != status_code_pb2.SUCCESS:
+            raise Exception(response.status)
+        params = response_to_model_params(
+            response=response, model_type_id=self.model_info.model_type_id, template=template
+        )
+        # yaml file
+        assert save_to.endswith('.yaml'), "File extension should be .yaml"
+        with open(save_to, 'w') as f:
+            yaml.dump(params, f, default_flow_style=False, sort_keys=False)
+        # updating the global model params
+        self.training_params.update(params)
+
+        return params
+
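
For orientation, a minimal sketch of the `get_params` flow above; the IDs and the template name are placeholders, not real resources:

```python
# Sketch only: all IDs and the template name below are placeholders.
from clarifai.client.model import Model

model = Model(model_id='my-classifier-id', user_id='me', app_id='my-app')

# Writes the editable training params to params.yaml and caches a copy
# on the client in model.training_params for later update_params()/train().
params = model.get_params(template='some-template-name')
print(sorted(params.keys()))
```
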
+    def update_params(self, **kwargs) -> None:
+        """Updates the model params for the model.
+
+        Args:
+            **kwargs: model params to update.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> model_params = model.get_params(template='template', save_to='model_params.yaml')
+            >>> model.update_params(batch_size=8, dataset_version='dataset_version_id')
+        """
+        if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
+            raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
+        if len(self.training_params) == 0:
+            raise UserError(
+                f"Run 'model.get_params' to get the params for the {self.model_info.model_type_id} model type"
+            )
+        # getting all the keys in nested dictionary
+        all_keys = [key for key in self.training_params.keys()] + [
+            key for key in self.training_params.values() if isinstance(key, dict) for key in key
+        ]
+        # checking if the given params are valid
+        if not set(kwargs.keys()).issubset(all_keys):
+            raise UserError("Invalid params")
+        # updating the global model params
+        for key, value in kwargs.items():
+            find_and_replace_key(self.training_params, key, value)
+
+    def get_param_info(self, param: str) -> Dict[str, Any]:
+        """Returns the param info for the param.
+
+        Args:
+            param (str): The param to get the info for.
+
+        Returns:
+            param_info (Dict): Dictionary of model param info for the param.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> model_params = model.get_params(template='template', save_to='model_params.yaml')
+            >>> model.get_param_info('param')
+        """
+        if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
+            raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
+        if len(self.training_params) == 0:
+            raise UserError(
+                f"Run 'model.get_params' to get the params for the {self.model_info.model_type_id} model type"
+            )
+
+        all_keys = [key for key in self.training_params.keys()] + [
+            key for key in self.training_params.values() if isinstance(key, dict) for key in key
+        ]
+        if param not in all_keys:
+            raise UserError(
+                f"Invalid param: '{param}' for model type '{self.model_info.model_type_id}'"
+            )
+        template = (
+            self.training_params['train_params']['template'] if 'template' in all_keys else None
+        )
+
+        request = service_pb2.ListModelTypesRequest(
+            user_app_id=self.user_app_id,
+        )
+        response = self._grpc_request(self.STUB.ListModelTypes, request)
+        if response.status.code != status_code_pb2.SUCCESS:
+            raise Exception(response.status)
+        param_info = response_to_param_info(
+            response=response,
+            model_type_id=self.model_info.model_type_id,
+            param=param,
+            template=template,
+        )
+
+        return param_info
+
+    def train(self, yaml_file: str = None) -> str:
+        """Trains the model based on the given yaml file or model params.
+
+        Args:
+            yaml_file (str): The yaml file for the model params.
+
+        Returns:
+            model_version_id (str): The model version ID for the model.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> model_params = model.get_params(template='template', save_to='model_params.yaml')
+            >>> model.train('model_params.yaml')
+        """
+        if not self.model_info.model_type_id:
+            self.load_info()
+        if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
+            raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
+        if not yaml_file and len(self.training_params) == 0:
+            raise UserError("Provide yaml file or run 'model.get_params()'")
+
+        if yaml_file:
+            with open(yaml_file, 'r') as file:
+                params_dict = yaml.safe_load(file)
+        else:
+            params_dict = self.training_params
+        # getting all the concepts for the model type
+        if self.model_info.model_type_id not in ["clusterer", "text-to-text"]:
+            concepts = self._list_concepts()
+        train_dict = params_parser(params_dict, concepts)
+        request = service_pb2.PostModelVersionsRequest(
+            user_app_id=self.user_app_id,
+            model_id=self.id,
+            model_versions=[resources_pb2.ModelVersion(**train_dict)],
+        )
+        response = self._grpc_request(self.STUB.PostModelVersions, request)
+        if response.status.code != status_code_pb2.SUCCESS:
+            raise Exception(response.status)
+        self.logger.info("\nModel Training Started\n%s", response.status)
+
+        return response.model.model_version.id
+
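
A hedged end-to-end sketch of `update_params` and `train` as defined above, polling with `training_status` (defined next); the IDs are placeholders and the sketch assumes `get_params` has already populated `model.training_params`:

```python
# Sketch only: assumes get_params() was called earlier and that the
# dataset version ID below exists in your app.
import time
from clarifai_grpc.grpc.api.status import status_code_pb2

model.update_params(batch_size=8, dataset_version='my-dataset-version-id')
version_id = model.train()  # kicks off training, returns the new version ID

# Poll until training settles; a production loop should also bail out
# on failed statuses rather than waiting forever.
while True:
    status = model.training_status(version_id=version_id, training_logs=False)
    if status.code == status_code_pb2.MODEL_TRAINED:
        break
    time.sleep(60)
```
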
+    def training_status(
+        self, version_id: str = None, training_logs: bool = False
+    ) -> Dict[str, str]:
+        """Get the training status for the model version. Also stores training logs.
+
+        Args:
+            version_id (str): The version ID to get the training status for.
+            training_logs (bool): Whether to save the training logs in a file.
+
+        Returns:
+            training_status (Dict): Dictionary of training status for the model version.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> model.training_status(version_id='version_id', training_logs=True)
+        """
+        if not version_id and not self.model_info.model_version.id:
+            raise UserError(
+                "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
+            )
+
+        self.load_info()
+        if self.model_info.model_type_id not in TRAINABLE_MODEL_TYPES:
+            raise UserError(f"Model type {self.model_info.model_type_id} is not trainable")
+
+        if training_logs:
+            try:
+                if self.model_info.model_version.train_log:
+                    log_response = requests.get(self.model_info.model_version.train_log)
+                    log_response.raise_for_status()  # Check for any HTTP errors
+                    with open(version_id + '.log', 'wb') as file:
+                        for chunk in log_response.iter_content(chunk_size=4096):  # 4KB
+                            file.write(chunk)
+                    self.logger.info(f"\nTraining logs are saved to the '{version_id + '.log'}' file")
+
+            except requests.exceptions.RequestException as e:
+                raise Exception(f"An error occurred while getting training logs: {e}")
+
+        return self.model_info.model_version.status
+
+    def delete_version(self, version_id: str) -> None:
+        """Deletes a model version for the Model.
+
+        Args:
+            version_id (str): The version ID to delete.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> model.delete_version(version_id='version_id')
+        """
+        request = service_pb2.DeleteModelVersionRequest(
+            user_app_id=self.user_app_id, model_id=self.id, version_id=version_id
+        )
+
+        response = self._grpc_request(self.STUB.DeleteModelVersion, request)
+        if response.status.code != status_code_pb2.SUCCESS:
+            raise Exception(response.status)
+        self.logger.info("\nModel Version Deleted\n%s", response.status)
+
+    def create_version(self, **kwargs) -> 'Model':
+        """Creates a model version for the Model.
+
+        Args:
+            **kwargs: Additional keyword arguments to be passed to Model Version.
+              - description (str): The description of the model version.
+              - concepts (list[Concept]): The concepts to associate with the model version.
+              - output_info (resources_pb2.OutputInfo): The output info to associate with the model version.
+
+        Returns:
+            Model: A Model object for the specified model ID.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model("url")
+                        or
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> model_version = model.create_version(description='model_version_description')
+        """
+        if self.model_info.model_type_id in TRAINABLE_MODEL_TYPES:
+            if 'pretrained_model_config' not in kwargs:
+                raise UserError(
+                    f"{self.model_info.model_type_id} is a trainable model type. Use 'model.train()' to train the model"
+                )
+
+        request = service_pb2.PostModelVersionsRequest(
+            user_app_id=self.user_app_id,
+            model_id=self.id,
+            model_versions=[resources_pb2.ModelVersion(**kwargs)],
+        )
+
+        response = self._grpc_request(self.STUB.PostModelVersions, request)
+        if response.status.code != status_code_pb2.SUCCESS:
+            raise Exception(response.status)
+        self.logger.info("\nModel Version created\n%s", response.status)
+
+        kwargs.update({'app_id': self.app_id, 'user_id': self.user_id})
+        dict_response = MessageToDict(response, preserving_proto_field_name=True)
+        kwargs = self.process_response_keys(dict_response['model'], 'model')
+
+        return Model(base_url=self.base, pat=self.pat, token=self.token, **kwargs)
+
+    def list_versions(
+        self, page_no: int = None, per_page: int = None
+    ) -> Generator['Model', None, None]:
+        """Lists all the versions for the model.
+
+        Args:
+            page_no (int): The page number to list.
+            per_page (int): The number of items per page.
+
+        Yields:
+            Model: Model objects for the versions of the model.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
+                        or
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> all_model_versions = list(model.list_versions())
+
+        Note:
+            Defaults to 16 per page if page_no is specified and per_page is not specified.
+            If both page_no and per_page are None, then lists all the resources.
+        """
+        request_data = dict(
+            user_app_id=self.user_app_id,
+            model_id=self.id,
+        )
+        all_model_versions_info = self.list_pages_generator(
+            self.STUB.ListModelVersions,
+            service_pb2.ListModelVersionsRequest,
+            request_data,
+            per_page=per_page,
+            page_no=page_no,
+        )
+
+        for model_version_info in all_model_versions_info:
+            model_version_info['id'] = model_version_info['model_version_id']
+            del model_version_info['model_version_id']
+            try:
+                del model_version_info['train_info']['dataset']['version']['metrics']
+            except KeyError:
+                pass
+            yield Model.from_auth_helper(
+                auth=self.auth_helper,
+                model_id=self.id,
+                **dict(self.kwargs, model_version=model_version_info),
+            )
+
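
A short usage sketch for `list_versions` and `delete_version` above; the `model` handle continues from the earlier sketches, and deletion is irreversible:

```python
# Sketch only: enumerate versions, then prune the oldest one.
versions = list(model.list_versions())
for v in versions:
    print(v.model_version.id, v.model_version.status.code)

if len(versions) > 1:
    # delete_version() permanently removes the version; gate it carefully.
    model.delete_version(version_id=versions[-1].model_version.id)
```
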
+    @property
+    def client(self):
+        if self._client is None:
+            request_template = service_pb2.PostModelOutputsRequest(
+                user_app_id=self.user_app_id,
+                model_id=self.id,
+                version_id=self.model_version.id,
+                model=self.model_info,
+                runner_selector=self._runner_selector,
+            )
+            self._client = ModelClient(self.STUB, request_template=request_template)
+        return self._client
+
+    def predict(self, *args, **kwargs):
+        """
+        Calls the model's predict() method with the given arguments.
+
+        If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
+        protos directly for compatibility with previous versions of the SDK.
+        """
+
+        inputs = None
+        if 'inputs' in kwargs:
+            inputs = kwargs['inputs']
+        elif args:
+            inputs = args[0]
+        if inputs and isinstance(inputs, list) and isinstance(inputs[0], resources_pb2.Input):
+            assert len(args) <= 1, (
+                "Cannot pass in raw protos and additional arguments at the same time."
+            )
+            inference_params = kwargs.get('inference_params', {})
+            output_config = kwargs.get('output_config', {})
+            return self.client._predict_by_proto(
+                inputs=inputs, inference_params=inference_params, output_config=output_config
+            )
+
+        return self.client.predict(*args, **kwargs)
+
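
The wrapper above dispatches on the first argument: a list of `resources_pb2.Input` protos takes the legacy `_predict_by_proto` path, while anything else is forwarded to the client's fetched method signature. A sketch of both conventions (the sample URL is Clarifai's public demo image; the keyword-argument names in the second call are illustrative, since they depend on the deployed model's signature):

```python
from clarifai.client.input import Inputs

# Legacy path: raw Input protos.
proto_inputs = [
    Inputs.get_input_from_url("", image_url="https://samples.clarifai.com/metro-north.jpg")
]
response = model.predict(inputs=proto_inputs)

# New path: typed kwargs resolved against the model's method signature
# (the argument name here is an assumption, not universal).
response = model.predict(prompt="What is the future of AI?")
```
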
+    def __getattr__(self, name):
+        try:
+            return getattr(self.model_info, name)
+        except AttributeError:
+            pass
+        if not self._added_methods:
+            # fetch and set all the model methods
+            self._added_methods = True
+            self.client.fetch()
+            for method_name in self.client._method_signatures.keys():
+                if not hasattr(self, method_name):
+                    setattr(self, method_name, getattr(self.client, method_name))
+        if hasattr(self.client, name):
+            return getattr(self.client, name)
+        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
+
+    def _check_predict_input_type(self, input_type: str) -> None:
+        """Checks if the input type is valid for the model.
+
+        Args:
+            input_type (str): The input type to check.
+        Returns:
+            None
+        """
+        if not input_type:
+            self.load_input_types()
+            if len(self.input_types) > 1:
+                raise UserError(
+                    "Model has multiple input types. Please use model.predict() for this multi-modal model."
+                )
+        else:
+            self.input_types = [input_type]
+            if self.input_types[0] not in {'image', 'text', 'video', 'audio'}:
+                raise UserError(
+                    f"Got input type {input_type} but expected one of image, text, video, audio."
+                )
+
+    def load_input_types(self) -> None:
+        """Loads the input types for the model.
+
+        Returns:
+            None
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
+                        or
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> model.load_input_types()
+        """
+        if self.input_types:
+            return self.input_types
+        if self.model_info.model_type_id == "":
+            self.load_info()
+        request = service_pb2.GetModelTypeRequest(
+            user_app_id=self.user_app_id,
+            model_type_id=self.model_info.model_type_id,
+        )
+        response = self._grpc_request(self.STUB.GetModelType, request)
+        if response.status.code != status_code_pb2.SUCCESS:
+            raise Exception(response.status)
+        self.input_types = response.model_type.input_fields
+
+    def _set_runner_selector(
+        self,
+        compute_cluster_id: str = None,
+        nodepool_id: str = None,
+        deployment_id: str = None,
+        user_id: str = None,
+    ):
+        runner_selector = None
+        if deployment_id and (compute_cluster_id or nodepool_id):
+            raise UserError(
+                "You can only specify one of deployment_id or compute_cluster_id and nodepool_id."
+            )
+
+        if deployment_id:
+            if not user_id and not os.environ.get('CLARIFAI_USER_ID'):
+                raise UserError(
+                    "User ID is required for model prediction with deployment ID, please provide user_id in the method call."
+                )
+            if not user_id:
+                user_id = os.environ.get('CLARIFAI_USER_ID')
+            runner_selector = Deployment.get_runner_selector(
+                user_id=user_id, deployment_id=deployment_id
+            )
+        elif compute_cluster_id and nodepool_id:
+            if not user_id and not os.environ.get('CLARIFAI_USER_ID'):
+                raise UserError(
+                    "User ID is required for model prediction with compute cluster ID and nodepool ID, please provide user_id in the method call."
+                )
+            if not user_id:
+                user_id = os.environ.get('CLARIFAI_USER_ID')
+            runner_selector = Nodepool.get_runner_selector(
+                user_id=user_id, compute_cluster_id=compute_cluster_id, nodepool_id=nodepool_id
+            )
+
+        # set the runner selector
+        self._runner_selector = runner_selector
+
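
`_set_runner_selector` enforces that a deployment is mutually exclusive with an explicit compute cluster/nodepool pair, and falls back to the `CLARIFAI_USER_ID` environment variable when `user_id` is omitted. A sketch, assuming the `Model` constructor forwards these keyword arguments to this helper (the IDs are placeholders):

```python
import os
from clarifai.client.model import Model

os.environ['CLARIFAI_USER_ID'] = 'me'  # fallback when user_id is not passed

# Route traffic through a dedicated deployment...
m1 = Model("https://clarifai.com/openai/chat-completion/models/GPT-4",
           deployment_id='my-deployment')

# ...or pin a compute cluster + nodepool instead; never both styles at once.
m2 = Model("https://clarifai.com/openai/chat-completion/models/GPT-4",
           compute_cluster_id='my-cluster', nodepool_id='my-nodepool')
```
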
+    def predict_by_filepath(
+        self,
+        filepath: str,
+        input_type: str = None,
+        inference_params: Dict = {},
+        output_config: Dict = {},
+    ):
+        """Predicts the model based on the given filepath.
+
+        Args:
+            filepath (str): The filepath to predict.
+            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
+            inference_params (dict): The inference params to override.
+            output_config (dict): The output config to override.
+              min_value (float): The minimum value of the prediction confidence to filter.
+              max_concepts (int): The maximum number of concepts to return.
+              select_concepts (list[Concept]): The concepts to select.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
+                        or
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> model_prediction = model.predict_by_filepath('/path/to/image.jpg')
+            >>> model_prediction = model.predict_by_filepath('/path/to/text.txt')
+        """
+        if not os.path.isfile(filepath):
+            raise UserError('Invalid filepath.')
+
+        with open(filepath, "rb") as f:
+            file_bytes = f.read()
+
+        return self.predict_by_bytes(file_bytes, input_type, inference_params, output_config)
+
+    def predict_by_bytes(
+        self,
+        input_bytes: bytes,
+        input_type: str = None,
+        inference_params: Dict = {},
+        output_config: Dict = {},
+    ):
+        """Predicts the model based on the given bytes.
+
+        Args:
+            input_bytes (bytes): File Bytes to predict on.
+            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
+            inference_params (dict): The inference params to override.
+            output_config (dict): The output config to override.
+              min_value (float): The minimum value of the prediction confidence to filter.
+              max_concepts (int): The maximum number of concepts to return.
+              select_concepts (list[Concept]): The concepts to select.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
+            >>> model_prediction = model.predict_by_bytes(b'Write a tweet on future of AI',
+                                                          inference_params=dict(temperature=str(0.7), max_tokens=30))
+        """
+        self._check_predict_input_type(input_type)
+
+        if self.input_types[0] == "image":
+            input_proto = Inputs.get_input_from_bytes("", image_bytes=input_bytes)
+        elif self.input_types[0] == "text":
+            input_proto = Inputs.get_input_from_bytes("", text_bytes=input_bytes)
+        elif self.input_types[0] == "video":
+            input_proto = Inputs.get_input_from_bytes("", video_bytes=input_bytes)
+        elif self.input_types[0] == "audio":
+            input_proto = Inputs.get_input_from_bytes("", audio_bytes=input_bytes)
+
+        return self.predict(
+            inputs=[input_proto], inference_params=inference_params, output_config=output_config
+        )
+
+    def predict_by_url(
+        self,
+        url: str,
+        input_type: str = None,
+        inference_params: Dict = {},
+        output_config: Dict = {},
+    ):
+        """Predicts the model based on the given URL.
+
+        Args:
+            url (str): The URL to predict.
+            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
+            inference_params (dict): The inference params to override.
+            output_config (dict): The output config to override.
+              min_value (float): The minimum value of the prediction confidence to filter.
+              max_concepts (int): The maximum number of concepts to return.
+              select_concepts (list[Concept]): The concepts to select.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
+                        or
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> model_prediction = model.predict_by_url('url')
+        """
+        self._check_predict_input_type(input_type)
+
+        if self.input_types[0] == "image":
+            input_proto = Inputs.get_input_from_url("", image_url=url)
+        elif self.input_types[0] == "text":
+            input_proto = Inputs.get_input_from_url("", text_url=url)
+        elif self.input_types[0] == "video":
+            input_proto = Inputs.get_input_from_url("", video_url=url)
+        elif self.input_types[0] == "audio":
+            input_proto = Inputs.get_input_from_url("", audio_url=url)
+
+        return self.predict(
+            inputs=[input_proto], inference_params=inference_params, output_config=output_config
+        )
+
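
The three convenience wrappers above all build a single `Inputs` proto and funnel into `predict()`; a sketch against the public general-image-recognition model (the local paths are placeholders):

```python
from clarifai.client.model import Model

model = Model("https://clarifai.com/clarifai/main/models/general-image-recognition")

by_url = model.predict_by_url("https://samples.clarifai.com/metro-north.jpg")
by_file = model.predict_by_filepath("/path/to/image.jpg")
with open("/path/to/image.jpg", "rb") as f:
    by_bytes = model.predict_by_bytes(f.read())

print(by_url.outputs[0].data.concepts[0].name)
```
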
+    def generate(self, *args, **kwargs):
+        """
+        Calls the model's generate() method with the given arguments.
+
+        If passed in request_pb2.PostModelOutputsRequest values, will send the model the raw
+        protos directly for compatibility with previous versions of the SDK.
+        """
+
+        inputs = None
+        if 'inputs' in kwargs:
+            inputs = kwargs['inputs']
+        elif args:
+            inputs = args[0]
+        if inputs and isinstance(inputs, list) and isinstance(inputs[0], resources_pb2.Input):
+            assert len(args) <= 1, (
+                "Cannot pass in raw protos and additional arguments at the same time."
+            )
+            inference_params = kwargs.get('inference_params', {})
+            output_config = kwargs.get('output_config', {})
+            return self.client._generate_by_proto(
+                inputs=inputs, inference_params=inference_params, output_config=output_config
+            )
+
+        return self.client.generate(*args, **kwargs)
+
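
`generate()` returns an iterator of partial responses rather than a single result; a sketch of draining it via the raw-proto path, mirroring the docstring examples (the prompt and params are placeholders):

```python
from clarifai.client.input import Inputs

stream = model.generate(
    inputs=[Inputs.get_input_from_bytes("", text_bytes=b'Write a tweet on future of AI')],
    inference_params=dict(temperature=str(0.7), max_tokens=30),
)
for response in stream:
    # Each streamed response carries a chunk of the generated text.
    print(response.outputs[0].data.text.raw, end="", flush=True)
```
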
+    def generate_by_filepath(
+        self,
+        filepath: str,
+        input_type: str = None,
+        inference_params: Dict = {},
+        output_config: Dict = {},
+    ):
+        """Generate the stream output on model based on the given filepath.
+
+        Args:
+            filepath (str): The filepath to predict.
+            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
+            inference_params (dict): The inference params to override.
+            output_config (dict): The output config to override.
+              min_value (float): The minimum value of the prediction confidence to filter.
+              max_concepts (int): The maximum number of concepts to return.
+              select_concepts (list[Concept]): The concepts to select.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
+                        or
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> stream_response = model.generate_by_filepath('/path/to/image.jpg', 'image', deployment_id='deployment_id')
+            >>> list_stream_response = [response for response in stream_response]
+        """
+        if not os.path.isfile(filepath):
+            raise UserError('Invalid filepath.')
+
+        with open(filepath, "rb") as f:
+            file_bytes = f.read()
+
+        return self.generate_by_bytes(
+            input_bytes=file_bytes,
+            input_type=input_type,
+            inference_params=inference_params,
+            output_config=output_config,
+        )
+
+    def generate_by_bytes(
+        self,
+        input_bytes: bytes,
+        input_type: str = None,
+        inference_params: Dict = {},
+        output_config: Dict = {},
+    ):
+        """Generate the stream output on model based on the given bytes.
+
+        Args:
+            input_bytes (bytes): File Bytes to predict on.
+            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
+            inference_params (dict): The inference params to override.
+            output_config (dict): The output config to override.
+              min_value (float): The minimum value of the prediction confidence to filter.
+              max_concepts (int): The maximum number of concepts to return.
+              select_concepts (list[Concept]): The concepts to select.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
+            >>> stream_response = model.generate_by_bytes(b'Write a tweet on future of AI',
+                                                          deployment_id='deployment_id',
+                                                          inference_params=dict(temperature=str(0.7), max_tokens=30))
+            >>> list_stream_response = [response for response in stream_response]
+        """
+        self._check_predict_input_type(input_type)
+
         if self.input_types[0] == "image":
-          yield [Inputs.get_input_from_bytes("", image_bytes=input_bytes)]
+            input_proto = Inputs.get_input_from_bytes("", image_bytes=input_bytes)
         elif self.input_types[0] == "text":
-          yield [Inputs.get_input_from_bytes("", text_bytes=input_bytes)]
+            input_proto = Inputs.get_input_from_bytes("", text_bytes=input_bytes)
         elif self.input_types[0] == "video":
-          yield [Inputs.get_input_from_bytes("", video_bytes=input_bytes)]
+            input_proto = Inputs.get_input_from_bytes("", video_bytes=input_bytes)
         elif self.input_types[0] == "audio":
-          yield [Inputs.get_input_from_bytes("", audio_bytes=input_bytes)]
-
-    return self.stream(
-        inputs=input_generator(), inference_params=inference_params, output_config=output_config)
-
-  def stream_by_url(self,
-                    url_iterator: Iterator[str],
-                    input_type: str = None,
-                    inference_params: Dict = {},
-                    output_config: Dict = {}):
-    """Stream the model output based on the given URL.
-
-    Args:
-        url_iterator (Iterator[str]): Iterator of URLs to predict.
-        input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio.
-        inference_params (dict): The inference params to override.
-        output_config (dict): The output config to override.
-          min_value (float): The minimum value of the prediction confidence to filter.
-          max_concepts (int): The maximum number of concepts to return.
-          select_concepts (list[Concept]): The concepts to select.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model("url")
-        >>> stream_response = model.stream_by_url(iter(['url']), deployment_id='deployment_id')
-        >>> list_stream_response = [response for response in stream_response]
-    """
-    self._check_predict_input_type(input_type)
-
-    def input_generator():
-      for url in url_iterator:
+            input_proto = Inputs.get_input_from_bytes("", audio_bytes=input_bytes)
+
+        return self.generate(
+            inputs=[input_proto], inference_params=inference_params, output_config=output_config
+        )
+
+    def generate_by_url(
+        self,
+        url: str,
+        input_type: str = None,
+        inference_params: Dict = {},
+        output_config: Dict = {},
+    ):
+        """Generate the stream output on model based on the given URL.
+
+        Args:
+            url (str): The URL to predict.
+            input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
+            inference_params (dict): The inference params to override.
+            output_config (dict): The output config to override.
+              min_value (float): The minimum value of the prediction confidence to filter.
+              max_concepts (int): The maximum number of concepts to return.
+              select_concepts (list[Concept]): The concepts to select.
+
+        Example:
+            >>> from clarifai.client.model import Model
+            >>> model = Model("url") # Example URL: https://clarifai.com/clarifai/main/models/general-image-recognition
+                        or
+            >>> model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
+            >>> stream_response = model.generate_by_url('url', deployment_id='deployment_id')
+            >>> list_stream_response = [response for response in stream_response]
+        """
+        self._check_predict_input_type(input_type)
+
         if self.input_types[0] == "image":
-          yield [Inputs.get_input_from_url("", image_url=url)]
+            input_proto = Inputs.get_input_from_url("", image_url=url)
         elif self.input_types[0] == "text":
-          yield [Inputs.get_input_from_url("", text_url=url)]
+            input_proto = Inputs.get_input_from_url("", text_url=url)
         elif self.input_types[0] == "video":
-          yield [Inputs.get_input_from_url("", video_url=url)]
+            input_proto = Inputs.get_input_from_url("", video_url=url)
         elif self.input_types[0] == "audio":
-          yield [Inputs.get_input_from_url("", audio_url=url)]
-
-    return self.stream(
-        inputs=input_generator(), inference_params=inference_params, output_config=output_config)
-
-  def _override_model_version(self, inference_params: Dict = {}, output_config: Dict = {}) -> None:
-    """Overrides the model version.
-
-    Args:
-        inference_params (dict): The inference params to override.
-        output_config (dict): The output config to override.
-          min_value (float): The minimum value of the prediction confidence to filter.
-          max_concepts (int): The maximum number of concepts to return.
-          select_concepts (list[Concept]): The concepts to select.
-          sample_ms (int): The number of milliseconds to sample.
-    """
-    params = Struct()
-    if inference_params is not None:
-      params.update(inference_params)
-
-    self.model_info.model_version.output_info.CopyFrom(
-        resources_pb2.OutputInfo(
-            output_config=resources_pb2.OutputConfig(**output_config), params=params))
-
-  def _list_concepts(self) -> List[str]:
-    """Lists all the concepts for the model type.
-
-    Returns:
-        concepts (List): List of concepts for the model type.
-    """
-    request_data = dict(user_app_id=self.user_app_id)
-    all_concepts_infos = self.list_pages_generator(self.STUB.ListConcepts,
-                                                   service_pb2.ListConceptsRequest, request_data)
-    return [concept_info['concept_id'] for concept_info in all_concepts_infos]
-
-  def load_info(self) -> None:
-    """Loads the model info."""
-    request = service_pb2.GetModelRequest(
-        user_app_id=self.user_app_id,
-        model_id=self.id,
-        version_id=self.model_info.model_version.id)
-    response = self._grpc_request(self.STUB.GetModel, request)
-
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(response.status)
-
-    dict_response = MessageToDict(response, preserving_proto_field_name=True)
-    self.kwargs = self.process_response_keys(dict_response['model'])
-    self.model_info = resources_pb2.Model()
-    dict_to_protobuf(self.model_info, self.kwargs)
-
-  def __str__(self):
-    if len(self.kwargs) < 10:
-      self.load_info()
-
-    init_params = [param for param in self.kwargs.keys()]
-    attribute_strings = [
-        f"{param}={getattr(self.model_info, param)}" for param in init_params
-        if hasattr(self.model_info, param)
-    ]
-    return f"Model Details: \n{', '.join(attribute_strings)}\n"
-
-  def list_evaluations(self) -> resources_pb2.EvalMetrics:
-    """List all eval_metrics of current model version
-
-    Raises:
-        Exception: Failed to call API
-
-    Returns:
-        resources_pb2.EvalMetrics
-    """
-    assert self.model_info.model_version.id, "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
-    request = service_pb2.ListModelVersionEvaluationsRequest(
-        user_app_id=self.user_app_id,
-        model_id=self.id,
-        model_version_id=self.model_info.model_version.id)
-    response = self._grpc_request(self.STUB.ListModelVersionEvaluations, request)
-
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(response.status)
-
-    return response.eval_metrics
-
-  def evaluate(self,
-               dataset: Dataset = None,
-               dataset_id: str = None,
-               dataset_app_id: str = None,
-               dataset_user_id: str = None,
-               dataset_version_id: str = None,
-               eval_id: str = None,
-               extended_metrics: dict = None,
-               eval_info: dict = None) -> resources_pb2.EvalMetrics:
-    """ Run evaluation
-
-    Args:
-        dataset (Dataset): If Clarifai Dataset is set, it will ignore other arguments prefixed with 'dataset_'.
-        dataset_id (str): Dataset Id. Default is None.
-        dataset_app_id (str): App ID for cross app evaluation, leave it as None to use Model App ID. Default is None.
-        dataset_user_id (str): User ID for cross app evaluation, leave it as None to use Model User ID. Default is None.
-        dataset_version_id (str): Dataset version Id. Default is None.
-        eval_id (str): Specific ID for the evaluation. You must specify this parameter to either overwrite the result with the dataset ID or format your evaluation in an informative manner. If you don't, it will use random ID from system. Default is None.
-        extended_metrics (dict): user custom metrics result. Default is None.
-        eval_info (dict): custom eval info. Default is empty dict.
-
-    Return
-        eval_metrics
-
-    """
-    assert self.model_info.model_version.id, "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
-
-    if dataset:
-      self.logger.info("Using dataset, ignore other arguments prefixed with 'dataset_'")
-      dataset_id = dataset.id
-      dataset_app_id = dataset.app_id
-      dataset_user_id = dataset.user_id
-      dataset_version_id = dataset.version.id
-    else:
-      self.logger.warning(
-          "Arguments prefixed with `dataset_` will be removed soon, please use dataset")
-
-    gt_dataset = resources_pb2.Dataset(
-        id=dataset_id,
-        app_id=dataset_app_id or self.auth_helper.app_id,
-        user_id=dataset_user_id or self.auth_helper.user_id,
-        version=resources_pb2.DatasetVersion(id=dataset_version_id))
-
-    metrics = None
-    if isinstance(extended_metrics, dict):
-      metrics = Struct()
-      metrics.update(extended_metrics)
-      metrics = resources_pb2.ExtendedMetrics(user_metrics=metrics)
-
-    eval_info_params = None
-    if isinstance(eval_info, dict):
-      eval_info_params = Struct()
-      eval_info_params.update(eval_info)
-      eval_info_params = resources_pb2.EvalInfo(params=eval_info_params)
-
-    eval_metric = resources_pb2.EvalMetrics(
-        id=eval_id,
-        model=resources_pb2.Model(
-            id=self.id,
-            app_id=self.auth_helper.app_id,
-            user_id=self.auth_helper.user_id,
-            model_version=resources_pb2.ModelVersion(id=self.model_info.model_version.id),
-        ),
-        extended_metrics=metrics,
-        ground_truth_dataset=gt_dataset,
-        eval_info=eval_info_params,
-    )
-    request = service_pb2.PostEvaluationsRequest(
-        user_app_id=self.user_app_id,
-        eval_metrics=[eval_metric],
-    )
-    response = self._grpc_request(self.STUB.PostEvaluations, request)
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(response.status)
-    self.logger.info(
-        "\nModel evaluation in progress. Kindly allow a few minutes for completion. Processing time may vary based on the model and dataset sizes."
-    )
-
-    return response.eval_metrics
-
-  def get_eval_by_id(
-      self,
-      eval_id: str,
-      label_counts=False,
-      test_set=False,
-      binary_metrics=False,
-      confusion_matrix=False,
-      metrics_by_class=False,
-      metrics_by_area=False,
-  ) -> resources_pb2.EvalMetrics:
-    """Get detail eval_metrics by eval_id with extra metric fields
-
-    Args:
-        eval_id (str): eval id
-        label_counts (bool, optional): Set True to get label counts. Defaults to False.
-        test_set (bool, optional): Set True to get test set. Defaults to False.
-        binary_metrics (bool, optional): Set True to get binary metric. Defaults to False.
-        confusion_matrix (bool, optional): Set True to get confusion matrix. Defaults to False.
-        metrics_by_class (bool, optional): Set True to get metrics by class. Defaults to False.
-        metrics_by_area (bool, optional): Set True to get metrics by area. Defaults to False.
-
-    Raises:
-        Exception: Failed to call API
-
-    Returns:
-        resources_pb2.EvalMetrics: eval_metrics
-    """
-    request = service_pb2.GetEvaluationRequest(
-        user_app_id=self.user_app_id,
-        evaluation_id=eval_id,
-        fields=resources_pb2.FieldsValue(
-            label_counts=label_counts,
-            test_set=test_set,
-            binary_metrics=binary_metrics,
-            confusion_matrix=confusion_matrix,
-            metrics_by_class=metrics_by_class,
-            metrics_by_area=metrics_by_area,
-        ))
-    response = self._grpc_request(self.STUB.GetEvaluation, request)
-
-    if response.status.code != status_code_pb2.SUCCESS:
-      raise Exception(response.status)
-
-    return response.eval_metrics
-
-  def get_latest_eval(self,
-                      label_counts=False,
-                      test_set=False,
-                      binary_metrics=False,
-                      confusion_matrix=False,
-                      metrics_by_class=False,
-                      metrics_by_area=False) -> Union[resources_pb2.EvalMetrics, None]:
-    """
-    Run `get_eval_by_id` method with latest `eval_id`
-
-    Args:
-        label_counts (bool, optional): Set True to get label counts. Defaults to False.
-        test_set (bool, optional): Set True to get test set. Defaults to False.
-        binary_metrics (bool, optional): Set True to get binary metric. Defaults to False.
-        confusion_matrix (bool, optional): Set True to get confusion matrix. Defaults to False.
-        metrics_by_class (bool, optional): Set True to get metrics by class. Defaults to False.
-        metrics_by_area (bool, optional): Set True to get metrics by area. Defaults to False.
-
-    Returns:
-        eval_metric if model is evaluated otherwise None.
-
-    """
-
-    _latest = self.list_evaluations()[0]
-    result = None
-    if _latest.status.code == status_code_pb2.MODEL_EVALUATED:
-      result = self.get_eval_by_id(
-          eval_id=_latest.id,
-          label_counts=label_counts,
-          test_set=test_set,
-          binary_metrics=binary_metrics,
-          confusion_matrix=confusion_matrix,
-          metrics_by_class=metrics_by_class,
-          metrics_by_area=metrics_by_area)
-
-    return result
-
-  def get_eval_by_dataset(self, dataset: Dataset) -> List[resources_pb2.EvalMetrics]:
-    """Get all eval data of dataset
-
-    Args:
-        dataset (Dataset): Clarifai dataset
-
-    Returns:
-        List[resources_pb2.EvalMetrics]
-    """
-    _id = dataset.id
-    app = dataset.app_id or self.app_id
-    user_id = dataset.user_id or self.user_id
-    version = dataset.version.id
-
-    list_eval: resources_pb2.EvalMetrics = self.list_evaluations()
-    outputs = []
-    for _eval in list_eval:
-      if _eval.status.code == status_code_pb2.MODEL_EVALUATED:
-        gt_ds = _eval.ground_truth_dataset
-        if (_id == gt_ds.id and user_id == gt_ds.user_id and app == gt_ds.app_id):
-          if not version or version == gt_ds.version.id:
-            outputs.append(_eval)
-
-    return outputs
-
-  def get_raw_eval(self,
-                   dataset: Dataset = None,
-                   eval_id: str = None,
-                   return_format: str = 'array') -> Union[resources_pb2.EvalTestSetEntry, Tuple[
-                       np.array, np.array, list, List[Input]], Tuple[List[dict], List[dict]]]:
-    """Get ground truths, predictions and input information. Do not pass dataset and eval_id at same time
-
-    Args:
-        dataset (Dataset): Clarifai dataset, get eval data of latest eval result of dataset.
-        eval_id (str): Evaluation ID, get eval data of specific eval id.
-        return_format (str, optional): Choice {proto, array, coco}. !Note that `coco` is only applicable for 'visual-detector'. Defaults to 'array'.
-
-    Returns:
-
-        Depends on `return_format`.
-
-        * if return_format == proto
-          `resources_pb2.EvalTestSetEntry`
-
-        * if return_format == array
-          `Tuple(np.array, np.array, List[str], List[Input])`: Tuple has 4 elements (y, y_pred, concept_ids, inputs).
-          y, y_pred, concept_ids can be used to compute metrics. 'inputs' can be use to download
-          - if model is 'classifier': 'y' and 'y_pred' are both arrays with a shape of (num_inputs,)
-          - if model is 'visual-detector': 'y' and 'y_pred' are arrays with a shape of (num_inputs,), where each element is array has shape (num_annotation, 6) consists of [x_min, y_min, x_max, y_max, concept_index, score]. The score is always 1 for 'y'
-
-        * if return_format == coco: Applicable only for 'visual-detector'
-          `Tuple[List[Dict], List[Dict]]`: Tuple has 2 elemnts where first element is COCO Ground Truth and last one is COCO Prediction Annotation
-
-    Example Usages:
-    ------
-    * Evaluate `visual-classifier` using sklearn
-
-    ```python
-    import os
-    from sklearn.metrics import accuracy_score
-    from sklearn.metrics import classification_report
-    import numpy as np
-    from clarifai.client.model import Model
-    from clarifai.client.dataset import Dataset
-    os.environ["CLARIFAI_PAT"] = "???"
-    model = Model(url="url/of/model/includes/version-id")
-    dataset = Dataset(dataset_id="dataset-id")
-    y, y_pred, clss, input_protos = model.get_raw_eval(dataset, return_format="array")
-    y = np.argmax(y, axis=1)
-    y_pred = np.argmax(y_pred, axis=1)
-    report = classification_report(y, y_pred, target_names=clss)
-    print(report)
-    acc = accuracy_score(y, y_pred)
-    print("acc ", acc)
-    ```
-
-    * Evaluate `visual-detector` using COCOeval
-
-    ```python
-    import os
-    import json
-    from pycocotools.coco import COCO
-    from pycocotools.cocoeval import COCOeval
-    from clarifai.client.model import Model
-    from clarifai.client.dataset import Dataset
-    os.environ["CLARIFAI_PAT"] = "???"  # Insert your PAT
-    model = Model(url=model_url)
-    dataset = Dataset(url=dataset_url)
-    y, y_pred = model.get_raw_eval(dataset, return_format="coco")
-    # save as files to load in COCO API
-    def save_annot(d, path):
-      with open(path, "w") as fp:
-        json.dump(d, fp, indent=2)
-    gt_path = os.path.join("gt.json")
-    pred_path = os.path.join("pred.json")
-    save_annot(y, gt_path)
-    save_annot(y_pred, pred_path)
-
-    cocoGt = COCO(gt_path)
-    cocoPred = COCO(pred_path)
-    cocoEval = COCOeval(cocoGt, cocoPred, "bbox")
-    cocoEval.evaluate()
-    cocoEval.accumulate()
-    cocoEval.summarize()  # Print out result of all classes with all area type
-    # Example:
-    # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.863
-    # Average Precision (AP) @[ IoU=0.50      | area= all | maxDets=100 ] = 0.973
-    # Average Precision (AP) @[ IoU=0.75      | area= all | maxDets=100 ] = 0.939
-    # ...
-    ```
-
-    """
-    from clarifai.utils.evaluation.testset_annotation_parser import (
-        parse_eval_annotation_classifier, parse_eval_annotation_detector,
-        parse_eval_annotation_detector_coco)
-
-    valid_model_types = ["visual-classifier", "text-classifier", "visual-detector"]
-    supported_format = ['proto', 'array', 'coco']
-    assert return_format in supported_format, ValueError(
-        f"Expected return_format in {supported_format}, got {return_format}")
-    self.load_info()
-    model_type_id = self.model_info.model_type_id
-    assert model_type_id in valid_model_types, \
-        f"This method only supports model types {valid_model_types}, but your model type is {self.model_info.model_type_id}."
-    assert not (dataset and
-                eval_id), "Using both `dataset` and `eval_id`, but only one should be passed."
-    assert not dataset or not eval_id, "Please provide either `dataset` or `eval_id`, but nothing was passed."
-    if model_type_id.endswith("-classifier") and return_format == "coco":
-      raise ValueError(
-          f"return_format coco only applies for `visual-detector`, however your model is `{model_type_id}`"
-      )
-
-    if dataset:
-      eval_by_ds = self.get_eval_by_dataset(dataset)
-      if len(eval_by_ds) == 0:
-        raise Exception(f"Model is not valuated with dataset: {dataset}")
-      eval_id = eval_by_ds[0].id
-
-    detail_eval_data = self.get_eval_by_id(eval_id=eval_id, test_set=True, metrics_by_class=True)
-
-    if return_format == "proto":
-      return detail_eval_data.test_set
-    else:
-      if model_type_id.endswith("-classifier"):
-        return parse_eval_annotation_classifier(detail_eval_data)
-      elif model_type_id == "visual-detector":
-        if return_format == "array":
-          return parse_eval_annotation_detector(detail_eval_data)
-        elif return_format == "coco":
-          return parse_eval_annotation_detector_coco(detail_eval_data)
-
-  def export(self, export_dir: str = None) -> None:
-    """Export the model, stores the exported model as model.tar file
-
-    Args:
-        export_dir (str, optional): If provided, the exported model will be saved in the specified directory else export status will be shown. Defaults to None.
-
-    Example:
-        >>> from clarifai.client.model import Model
-        >>> model = Model("url")
-        >>> model.export()
-            or
-        >>> model.export('/path/to/export_model_dir')
-    """
-    assert self.model_info.model_version.id, "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
-    if export_dir:
-      try:
-        if not os.path.exists(export_dir):
-          os.makedirs(export_dir)
-      except OSError as e:
-        raise Exception(f"An error occurred while creating the directory: {e}")
-
-    def _get_export_response():
-      get_export_request = service_pb2.GetModelVersionExportRequest(
-          user_app_id=self.user_app_id,
-          model_id=self.id,
-          version_id=self.model_info.model_version.id,
-      )
-      response = self._grpc_request(self.STUB.GetModelVersionExport, get_export_request)
-
-      if response.status.code != status_code_pb2.SUCCESS and response.status.code != status_code_pb2.CONN_DOES_NOT_EXIST:
-        raise Exception(response.status)
-
-      return response
-
-    def _download_exported_model(
-        get_model_export_response: service_pb2.SingleModelVersionExportResponse,
-        local_filepath: str):
-      model_export_url = get_model_export_response.export.url
-      model_export_file_size = get_model_export_response.export.size
-
-      with open(local_filepath, 'wb') as f:
-        progress = tqdm(
-            total=model_export_file_size, unit='B', unit_scale=True, desc="Exporting model")
-        downloaded_size = 0
-        range_size = RANGE_SIZE
-        chunk_size = CHUNK_SIZE
-        retry = False
-        retry_count = 0
-        while downloaded_size < model_export_file_size:
-          if downloaded_size + range_size >= model_export_file_size:
-            range_header = f"bytes={downloaded_size}-"
-          else:
-            range_header = f"bytes={downloaded_size}-{(downloaded_size+range_size-1)}"
-          try:
-            session = requests.Session()
-            retries = Retry(total=5, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504])
-            session.mount('https://', HTTPAdapter(max_retries=retries))
-            session.headers.update({'Authorization': self.metadata[0][1], 'Range': range_header})
-            response = session.get(model_export_url, stream=True)
-            response.raise_for_status()
-
-            for chunk in response.iter_content(chunk_size=chunk_size):
-              f.write(chunk)
-              progress.update(len(chunk))
-            f.flush()
-            os.fsync(f.fileno())
-            downloaded_size += range_size
-            if not retry:
-              range_size = (
-                  range_size * 2) if (range_size * 2) < MAX_RANGE_SIZE else MAX_RANGE_SIZE
-              chunk_size = (
-                  chunk_size * 2) if (chunk_size * 2) < MAX_CHUNK_SIZE else MAX_CHUNK_SIZE
-          except Exception as e:
-            self.logger.error(f"Error downloading model: {e}")
-            range_size = (
-                range_size // 2) if (range_size // 2) > MIN_RANGE_SIZE else MIN_RANGE_SIZE
-            chunk_size = (
-                chunk_size // 2) if (chunk_size // 2) > MIN_CHUNK_SIZE else MIN_CHUNK_SIZE
-            retry = True
-            retry_count += 1
-            f.seek(downloaded_size)
-            progress.reset(total=model_export_file_size)
-            progress.update(downloaded_size)
-            if retry_count > 5:
-              break
-        progress.close()
-
-      self.logger.info(
-          f"Model ID {self.id} with version {self.model_info.model_version.id} exported successfully to {export_dir}/model.tar"
-      )
-
-    get_export_response = _get_export_response()
-    if get_export_response.status.code == status_code_pb2.CONN_DOES_NOT_EXIST:
-      put_export_request = service_pb2.PutModelVersionExportsRequest(
-          user_app_id=self.user_app_id,
-          model_id=self.id,
-          version_id=self.model_info.model_version.id,
-      )
-
-      response = self._grpc_request(self.STUB.PutModelVersionExports, put_export_request)
-      if response.status.code != status_code_pb2.SUCCESS:
-        raise Exception(response.status)
-
-      self.logger.info(
-          f"Export process has started for Model ID {self.id}, Version {self.model_info.model_version.id}"
-      )
-      if export_dir:
-        start_time = time.time()
-        backoff_iterator = BackoffIterator(10)
-        while True:
-          get_export_response = _get_export_response()
-          if (get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING or \
-              get_export_response.export.status.code == status_code_pb2.MODEL_EXPORT_PENDING) and \
-              time.time() - start_time < MODEL_EXPORT_TIMEOUT:
+ input_proto = Inputs.get_input_from_url("", audio_url=url)
889
+
890
+ return self.generate(
891
+ inputs=[input_proto], inference_params=inference_params, output_config=output_config
892
+ )
893
+
894
+ def stream(self, *args, **kwargs):
895
+ """
896
+ Calls the model's stream() method with the given arguments.
897
+
898
+ If passed an iterable of lists of resources_pb2.Input protos, the raw
899
+ protos are sent to the model directly for compatibility with previous versions of the SDK.
900
+ """
901
+
902
+ use_proto_call = False
903
+ inputs = None
904
+ if 'inputs' in kwargs:
905
+ inputs = kwargs['inputs']
906
+ elif args:
907
+ inputs = args[0]
908
+ if inputs and isinstance(inputs, Iterable):
909
+ inputs_iter = iter(inputs)
910
+ try:
911
+ peek = next(inputs_iter)
912
+ except StopIteration:
913
+ pass
914
+ else:
915
+ use_proto_call = (
916
+ peek and isinstance(peek, list) and isinstance(peek[0], resources_pb2.Input)
917
+ )
918
+ # put back the peeked value
919
+ if inputs_iter is inputs:
920
+ inputs = itertools.chain([peek], inputs_iter)
921
+ if 'inputs' in kwargs:
922
+ kwargs['inputs'] = inputs
923
+ else:
924
+ args = (inputs,) + args[1:]
925
+
926
+ if use_proto_call:
927
+ assert len(args) <= 1, (
928
+ "Cannot pass in raw protos and additional arguments at the same time."
929
+ )
930
+ inference_params = kwargs.get('inference_params', {})
931
+ output_config = kwargs.get('output_config', {})
932
+ return self.client._stream_by_proto(
933
+ inputs=inputs, inference_params=inference_params, output_config=output_config
934
+ )
935
+
936
+ return self.client.stream(*args, **kwargs)
937
+
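The peek-and-restore logic above is subtle: consuming one item from a generator is irreversible, so the peeked item must be chained back in front of the remainder. A minimal standalone sketch of the pattern (a hypothetical helper, not part of the SDK):

```python
import itertools
from typing import Any, Iterable, Iterator, Tuple

def peek(iterable: Iterable) -> Tuple[Any, Iterator]:
    """Return the first item and an iterator equivalent to the original."""
    it = iter(iterable)  # no-op for iterators, fresh iterator for lists
    try:
        first = next(it)
    except StopIteration:
        return None, iter(())
    return first, itertools.chain([first], it)  # put the peeked item back

first, restored = peek(x * 2 for x in range(3))
assert first == 0
assert list(restored) == [0, 2, 4]  # nothing was lost by peeking
```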
938
+ def stream_by_filepath(
939
+ self,
940
+ filepath: str,
941
+ input_type: str = None,
942
+ inference_params: Dict = {},
943
+ output_config: Dict = {},
944
+ ):
945
+ """Stream the model output based on the given filepath.
946
+
947
+ Args:
948
+ filepath (str): The filepath to predict.
949
+ input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
950
+ inference_params (dict): The inference params to override.
951
+ output_config (dict): The output config to override.
952
+ min_value (float): The minimum value of the prediction confidence to filter.
953
+ max_concepts (int): The maximum number of concepts to return.
954
+ select_concepts (list[Concept]): The concepts to select.
955
+
956
+ Example:
957
+ >>> from clarifai.client.model import Model
958
+ >>> model = Model("url")
959
+ >>> stream_response = model.stream_by_filepath('/path/to/image.jpg', deployment_id='deployment_id')
960
+ >>> list_stream_response = [response for response in stream_response]
961
+ """
962
+ if not os.path.isfile(filepath):
963
+ raise UserError('Invalid filepath.')
964
+
965
+ with open(filepath, "rb") as f:
966
+ file_bytes = f.read()
967
+
968
+ return self.stream_by_bytes(
969
+ input_bytes_iterator=iter([file_bytes]),
970
+ input_type=input_type,
971
+ inference_params=inference_params,
972
+ output_config=output_config,
973
+ )
974
+
975
+ def stream_by_bytes(
976
+ self,
977
+ input_bytes_iterator: Iterator[bytes],
978
+ input_type: str = None,
979
+ inference_params: Dict = {},
980
+ output_config: Dict = {},
981
+ ):
982
+ """Stream the model output based on the given bytes.
983
+
984
+ Args:
985
+ input_bytes_iterator (Iterator[bytes]): Iterator of file bytes to predict on.
986
+ input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
987
+ inference_params (dict): The inference params to override.
988
+ output_config (dict): The output config to override.
989
+ min_value (float): The minimum value of the prediction confidence to filter.
990
+ max_concepts (int): The maximum number of concepts to return.
991
+ select_concepts (list[Concept]): The concepts to select.
992
+
993
+ Example:
994
+ >>> from clarifai.client.model import Model
995
+ >>> model = Model("https://clarifai.com/openai/chat-completion/models/GPT-4")
996
+ >>> stream_response = model.stream_by_bytes(iter([b'Write a tweet on future of AI']),
997
+ deployment_id='deployment_id',
998
+ inference_params=dict(temperature=str(0.7), max_tokens=30))
999
+ >>> list_stream_response = [response for response in stream_response]
1000
+ """
1001
+ self._check_predict_input_type(input_type)
1002
+
1003
+ def input_generator():
1004
+ for input_bytes in input_bytes_iterator:
1005
+ if self.input_types[0] == "image":
1006
+ yield [Inputs.get_input_from_bytes("", image_bytes=input_bytes)]
1007
+ elif self.input_types[0] == "text":
1008
+ yield [Inputs.get_input_from_bytes("", text_bytes=input_bytes)]
1009
+ elif self.input_types[0] == "video":
1010
+ yield [Inputs.get_input_from_bytes("", video_bytes=input_bytes)]
1011
+ elif self.input_types[0] == "audio":
1012
+ yield [Inputs.get_input_from_bytes("", audio_bytes=input_bytes)]
1013
+
1014
+ return self.stream(
1015
+ inputs=input_generator(),
1016
+ inference_params=inference_params,
1017
+ output_config=output_config,
1018
+ )
1019
+
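`stream_by_bytes` takes an iterator of byte chunks, so a large file can be fed without loading it fully into memory. A sketch under the assumption of a video model (the path and chunk size are illustrative):

```python
from typing import Iterator

def file_chunks(path: str, chunk_size: int = 1024 * 1024) -> Iterator[bytes]:
    """Yield a file's contents in fixed-size pieces."""
    with open(path, 'rb') as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            yield chunk

# stream_response = model.stream_by_bytes(file_chunks('/path/to/video.mp4'),
#                                         input_type='video')
```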
1020
+ def stream_by_url(
1021
+ self,
1022
+ url_iterator: Iterator[str],
1023
+ input_type: str = None,
1024
+ inference_params: Dict = {},
1025
+ output_config: Dict = {},
1026
+ ):
1027
+ """Stream the model output based on the given URL.
1028
+
1029
+ Args:
1030
+ url_iterator (Iterator[str]): Iterator of URLs to predict.
1031
+ input_type (str, optional): The type of input. Can be 'image', 'text', 'video' or 'audio'.
1032
+ inference_params (dict): The inference params to override.
1033
+ output_config (dict): The output config to override.
1034
+ min_value (float): The minimum value of the prediction confidence to filter.
1035
+ max_concepts (int): The maximum number of concepts to return.
1036
+ select_concepts (list[Concept]): The concepts to select.
1037
+
1038
+ Example:
1039
+ >>> from clarifai.client.model import Model
1040
+ >>> model = Model("url")
1041
+ >>> stream_response = model.stream_by_url(iter(['url']), deployment_id='deployment_id')
1042
+ >>> list_stream_response = [response for response in stream_response]
1043
+ """
1044
+ self._check_predict_input_type(input_type)
1045
+
1046
+ def input_generator():
1047
+ for url in url_iterator:
1048
+ if self.input_types[0] == "image":
1049
+ yield [Inputs.get_input_from_url("", image_url=url)]
1050
+ elif self.input_types[0] == "text":
1051
+ yield [Inputs.get_input_from_url("", text_url=url)]
1052
+ elif self.input_types[0] == "video":
1053
+ yield [Inputs.get_input_from_url("", video_url=url)]
1054
+ elif self.input_types[0] == "audio":
1055
+ yield [Inputs.get_input_from_url("", audio_url=url)]
1056
+
1057
+ return self.stream(
1058
+ inputs=input_generator(),
1059
+ inference_params=inference_params,
1060
+ output_config=output_config,
1061
+ )
1062
+
1063
+ def _override_model_version(
1064
+ self, inference_params: Dict = {}, output_config: Dict = {}
1065
+ ) -> None:
1066
+ """Overrides the model version.
1067
+
1068
+ Args:
1069
+ inference_params (dict): The inference params to override.
1070
+ output_config (dict): The output config to override.
1071
+ min_value (float): The minimum value of the prediction confidence to filter.
1072
+ max_concepts (int): The maximum number of concepts to return.
1073
+ select_concepts (list[Concept]): The concepts to select.
1074
+ sample_ms (int): The number of milliseconds to sample.
1075
+ """
1076
+ params = Struct()
1077
+ if inference_params is not None:
1078
+ params.update(inference_params)
1079
+
1080
+ self.model_info.model_version.output_info.CopyFrom(
1081
+ resources_pb2.OutputInfo(
1082
+ output_config=resources_pb2.OutputConfig(**output_config), params=params
1083
+ )
1084
+ )
1085
+
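For reference, the override packs plain dicts into protobuf messages roughly like this (the parameter names are illustrative, not a fixed schema):

```python
from google.protobuf.struct_pb2 import Struct

from clarifai_grpc.grpc.api import resources_pb2

params = Struct()
params.update({"temperature": 0.7, "max_tokens": 256})  # hypothetical inference params

# Equivalent of what _override_model_version copies into the model version.
output_info = resources_pb2.OutputInfo(
    output_config=resources_pb2.OutputConfig(min_value=0.5, max_concepts=10),
    params=params,
)
```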
1086
+ def _list_concepts(self) -> List[str]:
1087
+ """Lists all the concepts for the model type.
1088
+
1089
+ Returns:
1090
+ concepts (List): List of concepts for the model type.
1091
+ """
1092
+ request_data = dict(user_app_id=self.user_app_id)
1093
+ all_concepts_infos = self.list_pages_generator(
1094
+ self.STUB.ListConcepts, service_pb2.ListConceptsRequest, request_data
1095
+ )
1096
+ return [concept_info['concept_id'] for concept_info in all_concepts_infos]
1097
+
1098
+ def load_info(self) -> None:
1099
+ """Loads the model info."""
1100
+ request = service_pb2.GetModelRequest(
1101
+ user_app_id=self.user_app_id,
1102
+ model_id=self.id,
1103
+ version_id=self.model_info.model_version.id,
1104
+ )
1105
+ response = self._grpc_request(self.STUB.GetModel, request)
1106
+
1107
+ if response.status.code != status_code_pb2.SUCCESS:
1108
+ raise Exception(response.status)
1109
+
1110
+ dict_response = MessageToDict(response, preserving_proto_field_name=True)
1111
+ self.kwargs = self.process_response_keys(dict_response['model'])
1112
+ self.model_info = resources_pb2.Model()
1113
+ dict_to_protobuf(self.model_info, self.kwargs)
1114
+
1115
+ def __str__(self):
1116
+ if len(self.kwargs) < 10:
1117
+ self.load_info()
1118
+
1119
+ init_params = [param for param in self.kwargs.keys()]
1120
+ attribute_strings = [
1121
+ f"{param}={getattr(self.model_info, param)}"
1122
+ for param in init_params
1123
+ if hasattr(self.model_info, param)
1124
+ ]
1125
+ return f"Model Details: \n{', '.join(attribute_strings)}\n"
1126
+
1127
+ def list_evaluations(self) -> resources_pb2.EvalMetrics:
1128
+ """List all eval_metrics of current model version
1129
+
1130
+ Raises:
1131
+ Exception: Failed to call API
1132
+
1133
+ Returns:
1134
+ resources_pb2.EvalMetrics
1135
+ """
1136
+ assert self.model_info.model_version.id, (
1137
+ "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
1138
+ )
1139
+ request = service_pb2.ListModelVersionEvaluationsRequest(
1140
+ user_app_id=self.user_app_id,
1141
+ model_id=self.id,
1142
+ model_version_id=self.model_info.model_version.id,
1143
+ )
1144
+ response = self._grpc_request(self.STUB.ListModelVersionEvaluations, request)
1145
+
1146
+ if response.status.code != status_code_pb2.SUCCESS:
1147
+ raise Exception(response.status)
1148
+
1149
+ return response.eval_metrics
1150
+
1151
+ def evaluate(
1152
+ self,
1153
+ dataset: Dataset = None,
1154
+ dataset_id: str = None,
1155
+ dataset_app_id: str = None,
1156
+ dataset_user_id: str = None,
1157
+ dataset_version_id: str = None,
1158
+ eval_id: str = None,
1159
+ extended_metrics: dict = None,
1160
+ eval_info: dict = None,
1161
+ ) -> resources_pb2.EvalMetrics:
1162
+ """Run evaluation
1163
+
1164
+ Args:
1165
+ dataset (Dataset): If a Clarifai Dataset is provided, the other arguments prefixed with 'dataset_' are ignored.
1166
+ dataset_id (str): Dataset Id. Default is None.
1167
+ dataset_app_id (str): App ID for cross app evaluation, leave it as None to use Model App ID. Default is None.
1168
+ dataset_user_id (str): User ID for cross app evaluation, leave it as None to use Model User ID. Default is None.
1169
+ dataset_version_id (str): Dataset version Id. Default is None.
1170
+ eval_id (str): Specific ID for the evaluation. Set it to overwrite a previous result on the same dataset or to give the evaluation an informative ID; otherwise a random system-generated ID is used. Default is None.
1171
+ extended_metrics (dict): User-defined custom metrics. Default is None.
1172
+ eval_info (dict): Custom eval info. Default is None.
1173
+
1174
+ Returns:
1175
+ eval_metrics
1176
+
1177
+ """
1178
+ assert self.model_info.model_version.id, (
1179
+ "Model version is empty. Please provide `model_version` as arguments or with a URL as the format '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
1180
+ )
1181
+
1182
+ if dataset:
1183
+ self.logger.info("Using dataset, ignore other arguments prefixed with 'dataset_'")
1184
+ dataset_id = dataset.id
1185
+ dataset_app_id = dataset.app_id
1186
+ dataset_user_id = dataset.user_id
1187
+ dataset_version_id = dataset.version.id
1188
+ else:
1189
+ self.logger.warning(
1190
+ "Arguments prefixed with `dataset_` will be removed soon, please use dataset"
1191
+ )
1192
+
1193
+ gt_dataset = resources_pb2.Dataset(
1194
+ id=dataset_id,
1195
+ app_id=dataset_app_id or self.auth_helper.app_id,
1196
+ user_id=dataset_user_id or self.auth_helper.user_id,
1197
+ version=resources_pb2.DatasetVersion(id=dataset_version_id),
1198
+ )
1199
+
1200
+ metrics = None
1201
+ if isinstance(extended_metrics, dict):
1202
+ metrics = Struct()
1203
+ metrics.update(extended_metrics)
1204
+ metrics = resources_pb2.ExtendedMetrics(user_metrics=metrics)
1205
+
1206
+ eval_info_params = None
1207
+ if isinstance(eval_info, dict):
1208
+ eval_info_params = Struct()
1209
+ eval_info_params.update(eval_info)
1210
+ eval_info_params = resources_pb2.EvalInfo(params=eval_info_params)
1211
+
1212
+ eval_metric = resources_pb2.EvalMetrics(
1213
+ id=eval_id,
1214
+ model=resources_pb2.Model(
1215
+ id=self.id,
1216
+ app_id=self.auth_helper.app_id,
1217
+ user_id=self.auth_helper.user_id,
1218
+ model_version=resources_pb2.ModelVersion(id=self.model_info.model_version.id),
1219
+ ),
1220
+ extended_metrics=metrics,
1221
+ ground_truth_dataset=gt_dataset,
1222
+ eval_info=eval_info_params,
1223
+ )
1224
+ request = service_pb2.PostEvaluationsRequest(
1225
+ user_app_id=self.user_app_id,
1226
+ eval_metrics=[eval_metric],
1227
+ )
1228
+ response = self._grpc_request(self.STUB.PostEvaluations, request)
1229
+ if response.status.code != status_code_pb2.SUCCESS:
1230
+ raise Exception(response.status)
1231
+ self.logger.info(
1232
+ "\nModel evaluation in progress. Kindly allow a few minutes for completion. Processing time may vary based on the model and dataset sizes."
1233
+ )
1234
+
1235
+ return response.eval_metrics
1236
+
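A hedged usage sketch for `evaluate`, assuming a model URL that includes a version id and an existing dataset (all identifiers are placeholders):

```python
from clarifai.client.dataset import Dataset
from clarifai.client.model import Model

model = Model(
    url="https://clarifai.com/user_id/app_id/models/model_id/model_version_id/version_id"
)
dataset = Dataset(dataset_id="my-eval-dataset")

# Kicks off an asynchronous evaluation; results are fetched later.
eval_metrics = model.evaluate(dataset=dataset, eval_id="my-eval-run")
```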
1237
+ def get_eval_by_id(
1238
+ self,
1239
+ eval_id: str,
1240
+ label_counts=False,
1241
+ test_set=False,
1242
+ binary_metrics=False,
1243
+ confusion_matrix=False,
1244
+ metrics_by_class=False,
1245
+ metrics_by_area=False,
1246
+ ) -> resources_pb2.EvalMetrics:
1247
+ """Get detail eval_metrics by eval_id with extra metric fields
1248
+
1249
+ Args:
1250
+ eval_id (str): eval id
1251
+ label_counts (bool, optional): Set True to get label counts. Defaults to False.
1252
+ test_set (bool, optional): Set True to get test set. Defaults to False.
1253
+ binary_metrics (bool, optional): Set True to get binary metric. Defaults to False.
1254
+ confusion_matrix (bool, optional): Set True to get confusion matrix. Defaults to False.
1255
+ metrics_by_class (bool, optional): Set True to get metrics by class. Defaults to False.
1256
+ metrics_by_area (bool, optional): Set True to get metrics by area. Defaults to False.
1257
+
1258
+ Raises:
1259
+ Exception: Failed to call API
1260
+
1261
+ Returns:
1262
+ resources_pb2.EvalMetrics: eval_metrics
1263
+ """
1264
+ request = service_pb2.GetEvaluationRequest(
1265
+ user_app_id=self.user_app_id,
1266
+ evaluation_id=eval_id,
1267
+ fields=resources_pb2.FieldsValue(
1268
+ label_counts=label_counts,
1269
+ test_set=test_set,
1270
+ binary_metrics=binary_metrics,
1271
+ confusion_matrix=confusion_matrix,
1272
+ metrics_by_class=metrics_by_class,
1273
+ metrics_by_area=metrics_by_area,
1274
+ ),
1275
+ )
1276
+ response = self._grpc_request(self.STUB.GetEvaluation, request)
1277
+
1278
+ if response.status.code != status_code_pb2.SUCCESS:
1279
+ raise Exception(response.status)
1280
+
1281
+ return response.eval_metrics
1282
+
1283
+ def get_latest_eval(
1284
+ self,
1285
+ label_counts=False,
1286
+ test_set=False,
1287
+ binary_metrics=False,
1288
+ confusion_matrix=False,
1289
+ metrics_by_class=False,
1290
+ metrics_by_area=False,
1291
+ ) -> Union[resources_pb2.EvalMetrics, None]:
1292
+ """
1293
+ Run `get_eval_by_id` method with latest `eval_id`
1294
+
1295
+ Args:
1296
+ label_counts (bool, optional): Set True to get label counts. Defaults to False.
1297
+ test_set (bool, optional): Set True to get test set. Defaults to False.
1298
+ binary_metrics (bool, optional): Set True to get binary metric. Defaults to False.
1299
+ confusion_matrix (bool, optional): Set True to get confusion matrix. Defaults to False.
1300
+ metrics_by_class (bool, optional): Set True to get metrics by class. Defaults to False.
1301
+ metrics_by_area (bool, optional): Set True to get metrics by area. Defaults to False.
1302
+
1303
+ Returns:
1304
+ eval_metric if model is evaluated otherwise None.
1305
+
1306
+ """
1307
+
1308
+ _latest = self.list_evaluations()[0]
1309
+ result = None
1310
+ if _latest.status.code == status_code_pb2.MODEL_EVALUATED:
1311
+ result = self.get_eval_by_id(
1312
+ eval_id=_latest.id,
1313
+ label_counts=label_counts,
1314
+ test_set=test_set,
1315
+ binary_metrics=binary_metrics,
1316
+ confusion_matrix=confusion_matrix,
1317
+ metrics_by_class=metrics_by_class,
1318
+ metrics_by_area=metrics_by_area,
1319
+ )
1320
+
1321
+ return result
1322
+
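Because evaluation is asynchronous, callers typically poll until the latest run reaches MODEL_EVALUATED before fetching details; a rough sketch built on the methods above (the poll interval is arbitrary):

```python
import time

from clarifai_grpc.grpc.api.status import status_code_pb2

# Hypothetical polling loop; `model` as constructed in the earlier example.
metrics = None
while metrics is None:
    latest = model.list_evaluations()[0]
    if latest.status.code == status_code_pb2.MODEL_EVALUATED:
        metrics = model.get_eval_by_id(latest.id, confusion_matrix=True)
    else:
        time.sleep(30)  # evaluation still running; wait and retry
```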
1323
+ def get_eval_by_dataset(self, dataset: Dataset) -> List[resources_pb2.EvalMetrics]:
1324
+ """Get all eval data of dataset
1325
+
1326
+ Args:
1327
+ dataset (Dataset): Clarifai dataset
1328
+
1329
+ Returns:
1330
+ List[resources_pb2.EvalMetrics]
1331
+ """
1332
+ _id = dataset.id
1333
+ app = dataset.app_id or self.app_id
1334
+ user_id = dataset.user_id or self.user_id
1335
+ version = dataset.version.id
1336
+
1337
+ list_eval: resources_pb2.EvalMetrics = self.list_evaluations()
1338
+ outputs = []
1339
+ for _eval in list_eval:
1340
+ if _eval.status.code == status_code_pb2.MODEL_EVALUATED:
1341
+ gt_ds = _eval.ground_truth_dataset
1342
+ if _id == gt_ds.id and user_id == gt_ds.user_id and app == gt_ds.app_id:
1343
+ if not version or version == gt_ds.version.id:
1344
+ outputs.append(_eval)
1345
+
1346
+ return outputs
1347
+
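A short usage sketch: list the completed evaluations recorded against a given dataset (`model` and `dataset` as in the earlier examples):

```python
for eval_metric in model.get_eval_by_dataset(dataset):
    print(eval_metric.id, eval_metric.ground_truth_dataset.version.id)
```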
1348
+ def get_raw_eval(
1349
+ self, dataset: Dataset = None, eval_id: str = None, return_format: str = 'array'
1350
+ ) -> Union[
1351
+ resources_pb2.EvalTestSetEntry,
1352
+ Tuple[np.array, np.array, list, List[Input]],
1353
+ Tuple[List[dict], List[dict]],
1354
+ ]:
1355
+ """Get ground truths, predictions and input information. Do not pass dataset and eval_id at same time
1356
+
1357
+ Args:
1358
+ dataset (Dataset): Clarifai dataset, get eval data of latest eval result of dataset.
1359
+ eval_id (str): Evaluation ID, get eval data of specific eval id.
1360
+ return_format (str, optional): One of {proto, array, coco}. Note that `coco` is only applicable to 'visual-detector' models. Defaults to 'array'.
1361
+
1362
+ Returns:
1363
+
1364
+ Depends on `return_format`.
1365
+
1366
+ * if return_format == proto
1367
+ `resources_pb2.EvalTestSetEntry`
1368
+
1369
+ * if return_format == array
1370
+ `Tuple(np.array, np.array, List[str], List[Input])`: Tuple has 4 elements (y, y_pred, concept_ids, inputs).
1371
+ y, y_pred, concept_ids can be used to compute metrics. 'inputs' can be used to download the corresponding input data.
1372
+ - if model is 'classifier': 'y' and 'y_pred' are both arrays with a shape of (num_inputs,)
1373
+ - if model is 'visual-detector': 'y' and 'y_pred' are arrays with a shape of (num_inputs,), where each element is an array of shape (num_annotations, 6) consisting of [x_min, y_min, x_max, y_max, concept_index, score]. The score is always 1 for 'y'
1374
+
1375
+ * if return_format == coco: Applicable only for 'visual-detector'
1376
+ `Tuple[List[Dict], List[Dict]]`: Tuple has 2 elements: the first is the COCO ground truth annotation and the second is the COCO prediction annotation
1377
+
1378
+ Example Usages:
1379
+ ------
1380
+ * Evaluate `visual-classifier` using sklearn
1381
+
1382
+ ```python
1383
+ import os
1384
+ from sklearn.metrics import accuracy_score
1385
+ from sklearn.metrics import classification_report
1386
+ import numpy as np
1387
+ from clarifai.client.model import Model
1388
+ from clarifai.client.dataset import Dataset
1389
+ os.environ["CLARIFAI_PAT"] = "???"
1390
+ model = Model(url="url/of/model/includes/version-id")
1391
+ dataset = Dataset(dataset_id="dataset-id")
1392
+ y, y_pred, clss, input_protos = model.get_raw_eval(dataset, return_format="array")
1393
+ y = np.argmax(y, axis=1)
1394
+ y_pred = np.argmax(y_pred, axis=1)
1395
+ report = classification_report(y, y_pred, target_names=clss)
1396
+ print(report)
1397
+ acc = accuracy_score(y, y_pred)
1398
+ print("acc ", acc)
1399
+ ```
1400
+
1401
+ * Evaluate `visual-detector` using COCOeval
1402
+
1403
+ ```python
1404
+ import os
1405
+ import json
1406
+ from pycocotools.coco import COCO
1407
+ from pycocotools.cocoeval import COCOeval
1408
+ from clarifai.client.model import Model
1409
+ from clarifai.client.dataset import Dataset
1410
+ os.environ["CLARIFAI_PAT"] = "???" # Insert your PAT
1411
+ model = Model(url=model_url)
1412
+ dataset = Dataset(url=dataset_url)
1413
+ y, y_pred = model.get_raw_eval(dataset, return_format="coco")
1414
+ # save as files to load in COCO API
1415
+ def save_annot(d, path):
1416
+ with open(path, "w") as fp:
1417
+ json.dump(d, fp, indent=2)
1418
+ gt_path = os.path.join("gt.json")
1419
+ pred_path = os.path.join("pred.json")
1420
+ save_annot(y, gt_path)
1421
+ save_annot(y_pred, pred_path)
1422
+
1423
+ cocoGt = COCO(gt_path)
1424
+ cocoPred = COCO(pred_path)
1425
+ cocoEval = COCOeval(cocoGt, cocoPred, "bbox")
1426
+ cocoEval.evaluate()
1427
+ cocoEval.accumulate()
1428
+ cocoEval.summarize() # Print out result of all classes with all area type
1429
+ # Example:
1430
+ # Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.863
1431
+ # Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.973
1432
+ # Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.939
1433
+ # ...
1434
+ ```
1435
+
1436
+ """
1437
+ from clarifai.utils.evaluation.testset_annotation_parser import (
1438
+ parse_eval_annotation_classifier,
1439
+ parse_eval_annotation_detector,
1440
+ parse_eval_annotation_detector_coco,
1441
+ )
1442
+
1443
+ valid_model_types = ["visual-classifier", "text-classifier", "visual-detector"]
1444
+ supported_format = ['proto', 'array', 'coco']
1445
+ assert return_format in supported_format, ValueError(
1446
+ f"Expected return_format in {supported_format}, got {return_format}"
1447
+ )
1448
+ self.load_info()
1449
+ model_type_id = self.model_info.model_type_id
1450
+ assert model_type_id in valid_model_types, (
1451
+ f"This method only supports model types {valid_model_types}, but your model type is {self.model_info.model_type_id}."
1452
+ )
1453
+ assert not (dataset and eval_id), (
1454
+ "Using both `dataset` and `eval_id`, but only one should be passed."
1455
+ )
1456
+ assert dataset or eval_id, (
1457
+ "Please provide either `dataset` or `eval_id`, but nothing was passed."
1458
+ )
1459
+ if model_type_id.endswith("-classifier") and return_format == "coco":
1460
+ raise ValueError(
1461
+ f"return_format coco only applies for `visual-detector`, however your model is `{model_type_id}`"
1462
+ )
1463
+
1464
+ if dataset:
1465
+ eval_by_ds = self.get_eval_by_dataset(dataset)
1466
+ if len(eval_by_ds) == 0:
1467
+ raise Exception(f"Model is not valuated with dataset: {dataset}")
1468
+ eval_id = eval_by_ds[0].id
1469
+
1470
+ detail_eval_data = self.get_eval_by_id(
1471
+ eval_id=eval_id, test_set=True, metrics_by_class=True
1472
+ )
1473
+
1474
+ if return_format == "proto":
1475
+ return detail_eval_data.test_set
1476
+ elif model_type_id.endswith("-classifier"):
1477
+ return parse_eval_annotation_classifier(detail_eval_data)
1478
+ elif model_type_id == "visual-detector":
1479
+ if return_format == "array":
1480
+ return parse_eval_annotation_detector(detail_eval_data)
1481
+ elif return_format == "coco":
1482
+ return parse_eval_annotation_detector_coco(detail_eval_data)
1483
+
1484
+ def export(self, export_dir: str = None) -> None:
1485
+ """Export the model, stores the exported model as model.tar file
1486
+
1487
+ Args:
1488
+ export_dir (str, optional): If provided, the exported model will be saved in the specified directory; otherwise only the export status is shown. Defaults to None.
1489
+
1490
+ Example:
1491
+ >>> from clarifai.client.model import Model
1492
+ >>> model = Model("url")
1493
+ >>> model.export()
1494
+ or
1495
+ >>> model.export('/path/to/export_model_dir')
1496
+ """
1497
+ assert self.model_info.model_version.id, (
1498
+ "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing."
1499
+ )
1500
+ if export_dir:
1501
+ try:
1502
+ if not os.path.exists(export_dir):
1503
+ os.makedirs(export_dir)
1504
+ except OSError as e:
1505
+ raise Exception(f"An error occurred while creating the directory: {e}")
1506
+
1507
+ def _get_export_response():
1508
+ get_export_request = service_pb2.GetModelVersionExportRequest(
1509
+ user_app_id=self.user_app_id,
1510
+ model_id=self.id,
1511
+ version_id=self.model_info.model_version.id,
1512
+ )
1513
+ response = self._grpc_request(self.STUB.GetModelVersionExport, get_export_request)
1514
+
1515
+ if (
1516
+ response.status.code != status_code_pb2.SUCCESS
1517
+ and response.status.code != status_code_pb2.CONN_DOES_NOT_EXIST
1518
+ ):
1519
+ raise Exception(response.status)
1520
+
1521
+ return response
1522
+
1523
+ def _download_exported_model(
1524
+ get_model_export_response: service_pb2.SingleModelVersionExportResponse,
1525
+ local_filepath: str,
1526
+ ):
1527
+ model_export_url = get_model_export_response.export.url
1528
+ model_export_file_size = get_model_export_response.export.size
1529
+
1530
+ with open(local_filepath, 'wb') as f:
1531
+ progress = tqdm(
1532
+ total=model_export_file_size, unit='B', unit_scale=True, desc="Exporting model"
1533
+ )
1534
+ downloaded_size = 0
1535
+ range_size = RANGE_SIZE
1536
+ chunk_size = CHUNK_SIZE
1537
+ retry = False
1538
+ retry_count = 0
1539
+ while downloaded_size < model_export_file_size:
1540
+ if downloaded_size + range_size >= model_export_file_size:
1541
+ range_header = f"bytes={downloaded_size}-"
1542
+ else:
1543
+ range_header = (
1544
+ f"bytes={downloaded_size}-{(downloaded_size + range_size - 1)}"
1545
+ )
1546
+ try:
1547
+ session = requests.Session()
1548
+ retries = Retry(
1549
+ total=5, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504]
1550
+ )
1551
+ session.mount('https://', HTTPAdapter(max_retries=retries))
1552
+ session.headers.update(
1553
+ {'Authorization': self.metadata[0][1], 'Range': range_header}
1554
+ )
1555
+ response = session.get(model_export_url, stream=True)
1556
+ response.raise_for_status()
1557
+
1558
+ for chunk in response.iter_content(chunk_size=chunk_size):
1559
+ f.write(chunk)
1560
+ progress.update(len(chunk))
1561
+ f.flush()
1562
+ os.fsync(f.fileno())
1563
+ downloaded_size += range_size
1564
+ if not retry:
1565
+ range_size = (
1566
+ (range_size * 2)
1567
+ if (range_size * 2) < MAX_RANGE_SIZE
1568
+ else MAX_RANGE_SIZE
1569
+ )
1570
+ chunk_size = (
1571
+ (chunk_size * 2)
1572
+ if (chunk_size * 2) < MAX_CHUNK_SIZE
1573
+ else MAX_CHUNK_SIZE
1574
+ )
1575
+ except Exception as e:
1576
+ self.logger.error(f"Error downloading model: {e}")
1577
+ range_size = (
1578
+ (range_size // 2)
1579
+ if (range_size // 2) > MIN_RANGE_SIZE
1580
+ else MIN_RANGE_SIZE
1581
+ )
1582
+ chunk_size = (
1583
+ (chunk_size // 2)
1584
+ if (chunk_size // 2) > MIN_CHUNK_SIZE
1585
+ else MIN_CHUNK_SIZE
1586
+ )
1587
+ retry = True
1588
+ retry_count += 1
1589
+ f.seek(downloaded_size)
1590
+ progress.reset(total=model_export_file_size)
1591
+ progress.update(downloaded_size)
1592
+ if retry_count > 5:
1593
+ break
1594
+ progress.close()
1595
+
1596
+ self.logger.info(
1597
+ f"Model ID {self.id} with version {self.model_info.model_version.id} exported successfully to {export_dir}/model.tar"
1598
+ )
1599
+
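The download loop above grows its Range window on success and shrinks it on failure; stripped of the adaptive sizing, the core ranged-download pattern looks like this (URL, sizes, and retry policy are illustrative, and the server must honor Range headers):

```python
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

def download_in_ranges(url: str, local_path: str, total_size: int,
                       range_size: int = 4 * 1024 * 1024) -> None:
    """Fetch a file slice-by-slice using HTTP Range headers."""
    session = requests.Session()
    retries = Retry(total=5, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504])
    session.mount('https://', HTTPAdapter(max_retries=retries))
    downloaded = 0
    with open(local_path, 'wb') as f:
        while downloaded < total_size:
            end = min(downloaded + range_size, total_size) - 1  # inclusive end byte
            resp = session.get(url, headers={'Range': f'bytes={downloaded}-{end}'},
                               stream=True)
            resp.raise_for_status()
            for chunk in resp.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
            downloaded = end + 1
```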
1600
+ get_export_response = _get_export_response()
1601
+ if get_export_response.status.code == status_code_pb2.CONN_DOES_NOT_EXIST:
1602
+ put_export_request = service_pb2.PutModelVersionExportsRequest(
1603
+ user_app_id=self.user_app_id,
1604
+ model_id=self.id,
1605
+ version_id=self.model_info.model_version.id,
1606
+ )
1607
+
1608
+ response = self._grpc_request(self.STUB.PutModelVersionExports, put_export_request)
1609
+ if response.status.code != status_code_pb2.SUCCESS:
1610
+ raise Exception(response.status)
1611
+
1612
+ self.logger.info(
1613
+ f"Export process has started for Model ID {self.id}, Version {self.model_info.model_version.id}"
1614
+ )
1615
+ if export_dir:
1616
+ start_time = time.time()
1617
+ backoff_iterator = BackoffIterator(10)
1618
+ while True:
1619
+ get_export_response = _get_export_response()
1620
+ if (
1621
+ get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING
1622
+ or get_export_response.export.status.code
1623
+ == status_code_pb2.MODEL_EXPORT_PENDING
1624
+ ) and time.time() - start_time < MODEL_EXPORT_TIMEOUT:
1625
+ self.logger.info(
1626
+ f"Export process is ongoing for Model ID {self.id}, Version {self.model_info.model_version.id}. Please wait..."
1627
+ )
1628
+ time.sleep(next(backoff_iterator))
1629
+ elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
1630
+ _download_exported_model(
1631
+ get_export_response, os.path.join(export_dir, "model.tar")
1632
+ )
1633
+ break
1634
+ elif time.time() - start_time > MODEL_EXPORT_TIMEOUT:
1635
+ raise Exception(
1636
+ f"""Model Export took too long. Please try again or contact support@clarifai.com
1637
+ Req ID: {get_export_response.status.req_id}"""
1638
+ )
1639
+ elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
1640
+ if export_dir:
1641
+ _download_exported_model(
1642
+ get_export_response, os.path.join(export_dir, "model.tar")
1643
+ )
1644
+ else:
1645
+ self.logger.info(
1646
+ f"Model ID {self.id} with version {self.model_info.model_version.id} is already exported, you can download it from the following URL: {get_export_response.export.url}"
1647
+ )
1648
+ elif (
1649
+ get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING
1650
+ or get_export_response.export.status.code == status_code_pb2.MODEL_EXPORT_PENDING
1651
+ ):
1451
1652
  self.logger.info(
1452
1653
  f"Export process is ongoing for Model ID {self.id}, Version {self.model_info.model_version.id}. Please wait..."
1453
1654
  )
1454
- time.sleep(next(backoff_iterator))
1455
- elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
1456
- _download_exported_model(get_export_response, os.path.join(export_dir, "model.tar"))
1457
- break
1458
- elif time.time() - start_time > MODEL_EXPORT_TIMEOUT:
1459
- raise Exception(
1460
- f"""Model Export took too long. Please try again or contact support@clarifai.com
1461
- Req ID: {get_export_response.status.req_id}""")
1462
- elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED:
1463
- if export_dir:
1464
- _download_exported_model(get_export_response, os.path.join(export_dir, "model.tar"))
1465
- else:
1655
+
1656
+ @staticmethod
1657
+ def _make_pretrained_config_proto(
1658
+ input_field_maps: dict, output_field_maps: dict, url: str = None
1659
+ ):
1660
+ """Make PretrainedModelConfig for uploading new version
1661
+
1662
+ Args:
1663
+ input_field_maps (dict): dict
1664
+ output_field_maps (dict): dict
1665
+ url (str, optional): direct download url. Defaults to None.
1666
+ """
1667
+
1668
+ def _parse_fields_map(x):
1669
+ """parse input, outputs to Struct"""
1670
+ _fields_map = Struct()
1671
+ _fields_map.update(x)
1672
+ return _fields_map
1673
+
1674
+ input_fields_map = _parse_fields_map(input_field_maps)
1675
+ output_fields_map = _parse_fields_map(output_field_maps)
1676
+
1677
+ return resources_pb2.PretrainedModelConfig(
1678
+ input_fields_map=input_fields_map,
1679
+ output_fields_map=output_fields_map,
1680
+ model_zip_url=url,
1681
+ )
1682
+
1683
+ @staticmethod
1684
+ def _make_inference_params_proto(
1685
+ inference_parameters: List[Dict],
1686
+ ) -> List[resources_pb2.ModelTypeField]:
1687
+ """Convert list of Clarifai inference parameters to proto for uploading new version
1688
+
1689
+ Args:
1690
+ inference_parameters (List[Dict]): Each dict has keys {field_type, path, default_value, description}
1691
+
1692
+ Returns:
1693
+ List[resources_pb2.ModelTypeField]
1694
+ """
1695
+
1696
+ def _make_default_value_proto(dtype, value):
1697
+ if dtype == 1:
1698
+ return Value(bool_value=value)
1699
+ elif dtype == 2 or dtype == 21:
1700
+ return Value(string_value=value)
1701
+ elif dtype == 3:
1702
+ return Value(number_value=value)
1703
+
1704
+ iterative_proto_params = []
1705
+ for param in inference_parameters:
1706
+ dtype = param.get("field_type")
1707
+ proto_param = resources_pb2.ModelTypeField(
1708
+ path=param.get("path"),
1709
+ field_type=dtype,
1710
+ default_value=_make_default_value_proto(
1711
+ dtype=dtype, value=param.get("default_value")
1712
+ ),
1713
+ description=param.get("description"),
1714
+ )
1715
+ iterative_proto_params.append(proto_param)
1716
+ return iterative_proto_params
1717
+
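The `field_type` codes map onto proto Value kinds (1 → bool, 2 and 21 → string, 3 → number), so a configuration list for this helper might look like the following sketch (names and defaults are illustrative):

```python
inference_parameter_configs = [
    {"path": "temperature", "field_type": 3, "default_value": 0.7,
     "description": "Sampling temperature"},
    {"path": "system_prompt", "field_type": 2, "default_value": "You are helpful.",
     "description": "Prompt prepended to every request"},
    {"path": "use_cache", "field_type": 1, "default_value": True,
     "description": "Whether to reuse cached results"},
]
```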
1718
+ def create_version_by_file(
1719
+ self,
1720
+ file_path: str,
1721
+ input_field_maps: dict,
1722
+ output_field_maps: dict,
1723
+ inference_parameter_configs: List[dict] = None,
1724
+ model_version: str = None,
1725
+ part_id: int = 1,
1726
+ range_start: int = 0,
1727
+ no_cache: bool = False,
1728
+ no_resume: bool = False,
1729
+ description: str = "",
1730
+ ) -> 'Model':
1731
+ """Create model version by uploading local file
1732
+
1733
+ Args:
1734
+ file_path (str): path to built file.
1735
+ input_field_maps (dict): a dict where the key is clarifai input field and the value is triton model input,
1736
+ {clarifai_input_field: triton_input_field}.
1737
+ output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
1738
+ {clarifai_output_field1: triton_output_field1, clarifai_output_field2: triton_output_field2,...}.
1739
+ inference_parameter_configs (List[dict]): list of dicts - keys are path, field_type, default_value, description. Default is None
1740
+ model_version (str, optional): Custom model version. Defaults to None.
1741
+ part_id (int, optional): part id of file. Defaults to 1.
1742
+ range_start (int, optional): Byte offset of data already uploaded. Defaults to 0.
1743
+ no_cache (bool, optional): If True, do not save the upload cache that is used to resume uploads. Defaults to False.
1744
+ no_resume (bool, optional): If True, disable auto-resuming of uploads. Defaults to False.
1745
+ description (str): Model description.
1746
+
1747
+ Return:
1748
+ Model: instance of Model with the newly created version
1749
+
1750
+ """
1751
+ file_size = os.path.getsize(file_path)
1752
+ assert MIN_CHUNK_FOR_UPLOAD_FILE <= file_size <= MAX_CHUNK_FOR_UPLOAD_FILE, (
1753
+ "The file size exceeds the allowable limit, which ranges from 5MiB to 5GiB."
1754
+ )
1755
+
1756
+ pretrained_proto = Model._make_pretrained_config_proto(
1757
+ input_field_maps=input_field_maps, output_field_maps=output_field_maps
1758
+ )
1759
+ inference_param_proto = (
1760
+ Model._make_inference_params_proto(inference_parameter_configs)
1761
+ if inference_parameter_configs
1762
+ else None
1763
+ )
1764
+
1765
+ if file_size >= 1e9:
1766
+ chunk_size = 1024 * 50_000 # 50MB
1767
+ else:
1768
+ chunk_size = 1024 * 10_000 # 10MB
1769
+
1770
+ # self.logger.info(f"Chunk {chunk_size/1e6}MB, {file_size/chunk_size} steps")
1771
+ # self.logger.info(f" Max bytes per stream {MAX_SIZE_PER_STREAM}")
1772
+
1773
+ cache_dir = os.path.join(file_path, '..', '.cache')
1774
+ cache_upload_file = os.path.join(cache_dir, "upload.json")
1775
+ last_percent = 0
1776
+ if os.path.exists(cache_upload_file) and not no_resume:
1777
+ with open(cache_upload_file, "r") as fp:
1778
+ try:
1779
+ cache_info = json.load(fp)
1780
+ if isinstance(cache_info, dict):
1781
+ part_id = cache_info.get("part_id", part_id)
1782
+ chunk_size = cache_info.get("chunk_size", chunk_size)
1783
+ range_start = cache_info.get("range_start", range_start)
1784
+ model_version = cache_info.get("model_version", model_version)
1785
+ last_percent = cache_info.get("last_percent", last_percent)
1786
+ except Exception as e:
1787
+ self.logger.error(f"Skipping loading the upload cache due to error {e}.")
1788
+
1789
+ def init_model_version_upload(model_version):
1790
+ return service_pb2.PostModelVersionsUploadRequest(
1791
+ upload_config=service_pb2.PostModelVersionsUploadConfig(
1792
+ user_app_id=self.user_app_id,
1793
+ model_id=self.id,
1794
+ total_size=file_size,
1795
+ model_version=resources_pb2.ModelVersion(
1796
+ id=model_version,
1797
+ pretrained_model_config=pretrained_proto,
1798
+ description=description,
1799
+ output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto),
1800
+ ),
1801
+ )
1802
+ )
1803
+
1804
+ def _uploading(chunk, part_id, range_start, model_version):
1805
+ return service_pb2.PostModelVersionsUploadRequest(
1806
+ content_part=resources_pb2.UploadContentPart(
1807
+ data=chunk, part_number=part_id, range_start=range_start
1808
+ )
1809
+ )
1810
+
1811
+ finished_status = [status_code_pb2.SUCCESS, status_code_pb2.UPLOAD_DONE]
1812
+ uploading_in_progress_status = [
1813
+ status_code_pb2.UPLOAD_IN_PROGRESS,
1814
+ status_code_pb2.MODEL_UPLOADING,
1815
+ ]
1816
+
1817
+ def _save_cache(cache: dict):
1818
+ if not no_cache:
1819
+ os.makedirs(cache_dir, exist_ok=True)
1820
+ with open(cache_upload_file, "w") as fp:
1821
+ json.dump(cache, fp, indent=2)
1822
+
1823
+ def stream_request(fp, part_id, end_part_id, chunk_size, version):
1824
+ yield init_model_version_upload(version)
1825
+ for iter_part_id in range(part_id, end_part_id):
1826
+ chunk = fp.read(chunk_size)
1827
+ if not chunk:
1828
+ return
1829
+ yield _uploading(
1830
+ chunk=chunk,
1831
+ part_id=iter_part_id,
1832
+ range_start=chunk_size * (iter_part_id - 1),
1833
+ model_version=version,
1834
+ )
1835
+
1836
+ tqdm_loader = tqdm(total=100)
1837
+ if model_version:
1838
+ desc = f"Uploading model `{self.id}` version `{model_version}` ..."
1839
+ else:
1840
+ desc = f"Uploading model `{self.id}` ..."
1841
+ tqdm_loader.set_description(desc)
1842
+
1843
+ cache_uploading_info = {}
1844
+ cache_uploading_info["part_id"] = part_id
1845
+ cache_uploading_info["model_version"] = model_version
1846
+ cache_uploading_info["range_start"] = range_start
1847
+ cache_uploading_info["chunk_size"] = chunk_size
1848
+ cache_uploading_info["last_percent"] = last_percent
1849
+ tqdm_loader.update(last_percent)
1850
+ last_part_id = part_id
1851
+ n_chunks = file_size // chunk_size
1852
+ n_chunk_per_stream = MAX_SIZE_PER_STREAM // chunk_size or 1
1853
+
1854
+ def stream_and_logging(
1855
+ request, tqdm_loader, cache_uploading_info, expected_steps: int = None
1856
+ ):
1857
+ for st_step, st_response in enumerate(
1858
+ self.auth_helper.get_stub().PostModelVersionsUpload(
1859
+ request, metadata=self.auth_helper.metadata
1860
+ )
1861
+ ):
1862
+ if st_response.status.code in uploading_in_progress_status:
1863
+ if cache_uploading_info["model_version"]:
1864
+ assert (
1865
+ st_response.model_version_id == cache_uploading_info["model_version"]
1866
+ ), RuntimeError
1867
+ else:
1868
+ cache_uploading_info["model_version"] = st_response.model_version_id
1869
+ if st_step > 0:
1870
+ cache_uploading_info["part_id"] += 1
1871
+ cache_uploading_info["range_start"] += chunk_size
1872
+ _save_cache(cache_uploading_info)
1873
+
1874
+ if st_response.status.percent_completed:
1875
+ step_percent = (
1876
+ st_response.status.percent_completed
1877
+ - cache_uploading_info["last_percent"]
1878
+ )
1879
+ cache_uploading_info["last_percent"] += step_percent
1880
+ tqdm_loader.set_description(
1881
+ f"{st_response.status.description}, {st_response.status.details}, version id {cache_uploading_info.get('model_version')}"
1882
+ )
1883
+ tqdm_loader.update(step_percent)
1884
+ elif st_response.status.code not in finished_status + uploading_in_progress_status:
1885
+ # TODO: Find better way to handle error
1886
+ if expected_steps and st_step < expected_steps:
1887
+ raise Exception(f"Failed to upload model, error: {st_response.status}")
1888
+
1889
+ with open(file_path, 'rb') as fp:
1890
+ # seeking
1891
+ for _ in range(1, last_part_id):
1892
+ fp.read(chunk_size)
1893
+ # Stream even part
1894
+ end_part_id = n_chunks or 1
1895
+ for iter_part_id in range(int(last_part_id), int(n_chunks), int(n_chunk_per_stream)):
1896
+ end_part_id = iter_part_id + n_chunk_per_stream
1897
+ end_part_id = min(n_chunks, end_part_id)
1898
+ expected_steps = end_part_id - iter_part_id + 1  # +1 for the init request
1899
+ st_reqs = stream_request(
1900
+ fp,
1901
+ iter_part_id,
1902
+ end_part_id=end_part_id,
1903
+ chunk_size=chunk_size,
1904
+ version=cache_uploading_info["model_version"],
1905
+ )
1906
+ stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, expected_steps)
1907
+ # Stream last part
1908
+ accum_size = (end_part_id - 1) * chunk_size
1909
+ remained_size = file_size - accum_size if accum_size >= 0 else file_size
1910
+ st_reqs = stream_request(
1911
+ fp,
1912
+ end_part_id,
1913
+ end_part_id=end_part_id + 1,
1914
+ chunk_size=remained_size,
1915
+ version=cache_uploading_info["model_version"],
1916
+ )
1917
+ stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, 2)
1918
+
1919
+ # clean up cache
1920
+ if not no_cache:
1921
+ try:
1922
+ os.remove(cache_upload_file)
1923
+ except Exception:
1924
+ _save_cache({})
1925
+
1926
+ if cache_uploading_info["last_percent"] <= 100:
1927
+ tqdm_loader.update(100 - cache_uploading_info["last_percent"])
1928
+ tqdm_loader.set_description("Upload done")
1929
+
1930
+ tqdm_loader.set_description(
1931
+ f"Success uploading model {self.id}, new version {cache_uploading_info.get('model_version')}"
1932
+ )
1933
+
1934
+ return Model.from_auth_helper(
1935
+ auth=self.auth_helper,
1936
+ model_id=self.id,
1937
+ model_version=dict(id=cache_uploading_info.get('model_version')),
1938
+ )
1939
+
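An end-to-end sketch of uploading a new version from a local build. The field maps must match the Triton model's actual tensor names; every identifier below is a placeholder:

```python
from clarifai.client.model import Model

model = Model(url="https://clarifai.com/user_id/app_id/models/model_id")
new_version = model.create_version_by_file(
    file_path="/path/to/model.zip",
    input_field_maps={"text": "input_ids"},     # clarifai field -> triton input
    output_field_maps={"text": "output_text"},  # clarifai field -> triton output
    description="Uploaded from local build",
)
print(new_version.model_info.model_version.id)
```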
1940
+ def create_version_by_url(
1941
+ self,
1942
+ url: str,
1943
+ input_field_maps: dict,
1944
+ output_field_maps: dict,
1945
+ inference_parameter_configs: List[dict] = None,
1946
+ description: str = "",
1947
+ ) -> 'Model':
1948
+ """Upload a new version of an existing model in the Clarifai platform using direct download url.
1949
+
1950
+ Args:
1951
+ url (str): URL of the model zip file.
1952
+ input_field_maps (dict): a dict where the key is clarifai input field and the value is triton model input,
1953
+ {clarifai_input_field: triton_input_field}.
1954
+ output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
1955
+ {clarifai_output_field1: triton_output_field1, clarifai_output_field2: triton_output_field2,...}.
1956
+ inference_parameter_configs (List[dict]): list of dicts - keys are path, field_type, default_value, description. Default is None
1957
+ description (str): Model description.
1958
+
1959
+ Return:
1960
+ Model: instance of Model with the newly created version
1961
+ """
1962
+
1963
+ pretrained_proto = Model._make_pretrained_config_proto(
1964
+ input_field_maps=input_field_maps, output_field_maps=output_field_maps, url=url
1965
+ )
1966
+ inference_param_proto = (
1967
+ Model._make_inference_params_proto(inference_parameter_configs)
1968
+ if inference_parameter_configs
1969
+ else None
1970
+ )
1971
+ request = service_pb2.PostModelVersionsRequest(
1972
+ user_app_id=self.user_app_id,
1973
+ model_id=self.id,
1974
+ model_versions=[
1975
+ resources_pb2.ModelVersion(
1976
+ pretrained_model_config=pretrained_proto,
1977
+ description=description,
1978
+ output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto),
1979
+ )
1980
+ ],
1981
+ )
1982
+ response = self._grpc_request(self.STUB.PostModelVersions, request)
1983
+
1984
+ if response.status.code != status_code_pb2.SUCCESS:
1985
+ raise Exception(f"Failed to upload model, error: {response.status}")
1466
1986
  self.logger.info(
1467
- f"Model ID {self.id} with version {self.model_info.model_version.id} is already exported, you can download it from the following URL: {get_export_response.export.url}"
1987
+ f"Success uploading model {self.id}, new version {response.model.model_version.id}"
1988
+ )
1989
+
1990
+ return Model.from_auth_helper(
1991
+ auth=self.auth_helper,
1992
+ model_id=self.id,
1993
+ model_version=dict(id=response.model.model_version.id),
1468
1994
  )
1469
- elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING or \
1470
- get_export_response.export.status.code == status_code_pb2.MODEL_EXPORT_PENDING:
1471
- self.logger.info(
1472
- f"Export process is ongoing for Model ID {self.id}, Version {self.model_info.model_version.id}. Please wait..."
1473
- )
1474
-
1475
- @staticmethod
1476
- def _make_pretrained_config_proto(input_field_maps: dict,
1477
- output_field_maps: dict,
1478
- url: str = None):
1479
- """Make PretrainedModelConfig for uploading new version
1480
-
1481
- Args:
1482
- input_field_maps (dict): dict
1483
- output_field_maps (dict): dict
1484
- url (str, optional): direct download url. Defaults to None.
1485
- """
1486
-
1487
- def _parse_fields_map(x):
1488
- """parse input, outputs to Struct"""
1489
- _fields_map = Struct()
1490
- _fields_map.update(x)
1491
- return _fields_map
1492
-
1493
- input_fields_map = _parse_fields_map(input_field_maps)
1494
- output_fields_map = _parse_fields_map(output_field_maps)
1495
-
1496
- return resources_pb2.PretrainedModelConfig(
1497
- input_fields_map=input_fields_map, output_fields_map=output_fields_map, model_zip_url=url)
1498
-
1499
- @staticmethod
1500
- def _make_inference_params_proto(
1501
- inference_parameters: List[Dict]) -> List[resources_pb2.ModelTypeField]:
1502
- """Convert list of Clarifai inference parameters to proto for uploading new version
1503
-
1504
- Args:
1505
- inference_parameters (List[Dict]): Each dict has keys {field_type, path, default_value, description}
1506
-
1507
- Returns:
1508
- List[resources_pb2.ModelTypeField]
1509
- """
1510
-
1511
- def _make_default_value_proto(dtype, value):
1512
- if dtype == 1:
1513
- return Value(bool_value=value)
1514
- elif dtype == 2 or dtype == 21:
1515
- return Value(string_value=value)
1516
- elif dtype == 3:
1517
- return Value(number_value=value)
1518
-
1519
- iterative_proto_params = []
1520
- for param in inference_parameters:
1521
- dtype = param.get("field_type")
1522
- proto_param = resources_pb2.ModelTypeField(
1523
- path=param.get("path"),
1524
- field_type=dtype,
1525
- default_value=_make_default_value_proto(dtype=dtype, value=param.get("default_value")),
1526
- description=param.get("description"),
1527
- )
1528
- iterative_proto_params.append(proto_param)
1529
- return iterative_proto_params
1530
-
1531
- def create_version_by_file(self,
1532
- file_path: str,
1533
- input_field_maps: dict,
1534
- output_field_maps: dict,
1535
- inference_parameter_configs: dict = None,
1536
- model_version: str = None,
1537
- part_id: int = 1,
1538
- range_start: int = 0,
1539
- no_cache: bool = False,
1540
- no_resume: bool = False,
1541
- description: str = "") -> 'Model':
1542
- """Create model version by uploading local file
1543
-
1544
- Args:
1545
- file_path (str): path to built file.
1546
- input_field_maps (dict): a dict where the key is clarifai input field and the value is triton model input,
1547
- {clarifai_input_field: triton_input_filed}.
1548
- output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
1549
- {clarifai_output_field1: triton_output_filed1, clarifai_output_field2: triton_output_filed2,...}.
1550
- inference_parameter_configs (List[dict]): list of dicts - keys are path, field_type, default_value, description. Default is None
1551
- model_version (str, optional): Custom model version. Defaults to None.
1552
- part_id (int, optional): part id of file. Defaults to 1.
1553
- range_start (int, optional): range of uploaded size. Defaults to 0.
1554
- no_cache (bool, optional): not saving uploading cache that is used to resume uploading. Defaults to False.
1555
- no_resume (bool, optional): disable auto resume upload. Defaults to False.
1556
- description (str): Model description.
1557
-
1558
- Return:
1559
- Model: instance of Model with new created version
1560
-
1561
- """
1562
- file_size = os.path.getsize(file_path)
1563
- assert MIN_CHUNK_FOR_UPLOAD_FILE <= file_size <= MAX_CHUNK_FOR_UPLOAD_FILE, "The file size exceeds the allowable limit, which ranges from 5MiB to 5GiB."
1564
-
1565
- pretrained_proto = Model._make_pretrained_config_proto(
1566
- input_field_maps=input_field_maps, output_field_maps=output_field_maps)
1567
- inference_param_proto = Model._make_inference_params_proto(
1568
- inference_parameter_configs) if inference_parameter_configs else None
1569
-
1570
- if file_size >= 1e9:
1571
- chunk_size = 1024 * 50_000 # 50MB
1572
- else:
1573
- chunk_size = 1024 * 10_000 # 10MB
1574
-
1575
- #self.logger.info(f"Chunk {chunk_size/1e6}MB, {file_size/chunk_size} steps")
1576
- #self.logger.info(f" Max bytes per stream {MAX_SIZE_PER_STREAM}")
1577
-
1578
- cache_dir = os.path.join(file_path, '..', '.cache')
1579
- cache_upload_file = os.path.join(cache_dir, "upload.json")
1580
- last_percent = 0
1581
- if os.path.exists(cache_upload_file) and not no_resume:
1582
- with open(cache_upload_file, "r") as fp:
1583
- try:
1584
- cache_info = json.load(fp)
1585
- if isinstance(cache_info, dict):
1586
- part_id = cache_info.get("part_id", part_id)
1587
- chunk_size = cache_info.get("chunk_size", chunk_size)
1588
- range_start = cache_info.get("range_start", range_start)
1589
- model_version = cache_info.get("model_version", model_version)
1590
- last_percent = cache_info.get("last_percent", last_percent)
1591
- except Exception as e:
1592
- self.logger.error(f"Skipping loading the upload cache due to error {e}.")
1593
-
1594
- def init_model_version_upload(model_version):
1595
- return service_pb2.PostModelVersionsUploadRequest(
1596
- upload_config=service_pb2.PostModelVersionsUploadConfig(
1597
- user_app_id=self.user_app_id,
1598
- model_id=self.id,
1599
- total_size=file_size,
1600
- model_version=resources_pb2.ModelVersion(
1601
- id=model_version,
1602
- pretrained_model_config=pretrained_proto,
1603
- description=description,
1604
- output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto)),
1605
- ))
1606
-
1607
- def _uploading(chunk, part_id, range_start, model_version):
1608
- return service_pb2.PostModelVersionsUploadRequest(
1609
- content_part=resources_pb2.UploadContentPart(
1610
- data=chunk, part_number=part_id, range_start=range_start))
1611
-
1612
- finished_status = [status_code_pb2.SUCCESS, status_code_pb2.UPLOAD_DONE]
1613
- uploading_in_progress_status = [
1614
- status_code_pb2.UPLOAD_IN_PROGRESS, status_code_pb2.MODEL_UPLOADING
1615
- ]
1616
-
1617
- def _save_cache(cache: dict):
1618
- if not no_cache:
1619
- os.makedirs(cache_dir, exist_ok=True)
1620
- with open(cache_upload_file, "w") as fp:
1621
- json.dump(cache, fp, indent=2)
1622
-
1623
- def stream_request(fp, part_id, end_part_id, chunk_size, version):
1624
- yield init_model_version_upload(version)
1625
- for iter_part_id in range(part_id, end_part_id):
1626
- chunk = fp.read(chunk_size)
1627
- if not chunk:
1628
- return
1629
- yield _uploading(
1630
- chunk=chunk,
1631
- part_id=iter_part_id,
1632
- range_start=chunk_size * (iter_part_id - 1),
1633
- model_version=version)
1634
-
1635
- tqdm_loader = tqdm(total=100)
1636
- if model_version:
1637
- desc = f"Uploading model `{self.id}` version `{model_version}` ..."
1638
- else:
1639
- desc = f"Uploading model `{self.id}` ..."
1640
- tqdm_loader.set_description(desc)
1641
-
1642
- cache_uploading_info = {}
1643
- cache_uploading_info["part_id"] = part_id
1644
- cache_uploading_info["model_version"] = model_version
1645
- cache_uploading_info["range_start"] = range_start
1646
- cache_uploading_info["chunk_size"] = chunk_size
1647
- cache_uploading_info["last_percent"] = last_percent
1648
- tqdm_loader.update(last_percent)
1649
- last_part_id = part_id
1650
- n_chunks = file_size // chunk_size
1651
- n_chunk_per_stream = MAX_SIZE_PER_STREAM // chunk_size or 1
1652
-
1653
- def stream_and_logging(request, tqdm_loader, cache_uploading_info, expected_steps: int = None):
1654
- for st_step, st_response in enumerate(self.auth_helper.get_stub().PostModelVersionsUpload(
1655
- request, metadata=self.auth_helper.metadata)):
1656
- if st_response.status.code in uploading_in_progress_status:
1657
- if cache_uploading_info["model_version"]:
1658
- assert st_response.model_version_id == cache_uploading_info[
1659
- "model_version"], RuntimeError
1660
- else:
1661
- cache_uploading_info["model_version"] = st_response.model_version_id
1662
- if st_step > 0:
1663
- cache_uploading_info["part_id"] += 1
1664
- cache_uploading_info["range_start"] += chunk_size
1665
- _save_cache(cache_uploading_info)
1666
-
1667
- if st_response.status.percent_completed:
1668
- step_percent = st_response.status.percent_completed - cache_uploading_info["last_percent"]
1669
- cache_uploading_info["last_percent"] += step_percent
1670
- tqdm_loader.set_description(
1671
- f"{st_response.status.description}, {st_response.status.details}, version id {cache_uploading_info.get('model_version')}"
1672
- )
1673
- tqdm_loader.update(step_percent)
1674
- elif st_response.status.code not in finished_status + uploading_in_progress_status:
1675
- # TODO: Find better way to handle error
1676
- if expected_steps and st_step < expected_steps:
1677
- raise Exception(f"Failed to upload model, error: {st_response.status}")
1678
-
1679
- with open(file_path, 'rb') as fp:
1680
- # seeking
1681
- for _ in range(1, last_part_id):
1682
- fp.read(chunk_size)
1683
- # Stream even part
1684
- end_part_id = n_chunks or 1
1685
- for iter_part_id in range(int(last_part_id), int(n_chunks), int(n_chunk_per_stream)):
1686
- end_part_id = iter_part_id + n_chunk_per_stream
1687
- if end_part_id >= n_chunks:
1688
- end_part_id = n_chunks
1689
- expected_steps = end_part_id - iter_part_id + 1 # init step
1690
- st_reqs = stream_request(
1691
- fp,
1692
- iter_part_id,
1693
- end_part_id=end_part_id,
1694
- chunk_size=chunk_size,
1695
- version=cache_uploading_info["model_version"])
1696
- stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, expected_steps)
1697
- # Stream last part
1698
- accum_size = (end_part_id - 1) * chunk_size
1699
- remained_size = file_size - accum_size if accum_size >= 0 else file_size
1700
- st_reqs = stream_request(
1701
- fp,
1702
- end_part_id,
1703
- end_part_id=end_part_id + 1,
1704
- chunk_size=remained_size,
1705
- version=cache_uploading_info["model_version"])
1706
- stream_and_logging(st_reqs, tqdm_loader, cache_uploading_info, 2)
1707
-
1708
- # clean up cache
1709
- if not no_cache:
1710
- try:
1711
- os.remove(cache_upload_file)
1712
- except Exception:
1713
- _save_cache({})
1714
-
1715
- if cache_uploading_info["last_percent"] <= 100:
1716
- tqdm_loader.update(100 - cache_uploading_info["last_percent"])
1717
- tqdm_loader.set_description("Upload done")
1718
-
1719
- tqdm_loader.set_description(
1720
- f"Success uploading model {self.id}, new version {cache_uploading_info.get('model_version')}"
1721
- )
1722
-
1723
- return Model.from_auth_helper(
1724
- auth=self.auth_helper,
1725
- model_id=self.id,
1726
- model_version=dict(id=cache_uploading_info.get('model_version')))
1727
-
1728
- def create_version_by_url(self,
1729
- url: str,
1730
- input_field_maps: dict,
1731
- output_field_maps: dict,
1732
- inference_parameter_configs: List[dict] = None,
1733
- description: str = "") -> 'Model':
1734
- """Upload a new version of an existing model in the Clarifai platform using direct download url.
1735
-
1736
- Args:
1737
- url (str]): url of zip of model
1738
- input_field_maps (dict): a dict where the key is clarifai input field and the value is triton model input,
1739
- {clarifai_input_field: triton_input_filed}.
1740
- output_field_maps (dict): a dict where the keys are clarifai output fields and the values are triton model outputs,
1741
- {clarifai_output_field1: triton_output_filed1, clarifai_output_field2: triton_output_filed2,...}.
1742
- inference_parameter_configs (List[dict]): list of dicts - keys are path, field_type, default_value, description. Default is None
1743
- description (str): Model description.
1744
-
1745
- Return:
1746
- Model: instance of Model with new created version
1747
- """
1748
-
1749
- pretrained_proto = Model._make_pretrained_config_proto(
1750
- input_field_maps=input_field_maps, output_field_maps=output_field_maps, url=url)
1751
- inference_param_proto = Model._make_inference_params_proto(
1752
- inference_parameter_configs) if inference_parameter_configs else None
1753
- request = service_pb2.PostModelVersionsRequest(
1754
- user_app_id=self.user_app_id,
1755
- model_id=self.id,
1756
- model_versions=[
1757
- resources_pb2.ModelVersion(
1758
- pretrained_model_config=pretrained_proto,
1759
- description=description,
1760
- output_info=resources_pb2.OutputInfo(params_specs=inference_param_proto))
1761
- ])
1762
- response = self._grpc_request(self.STUB.PostModelVersions, request)
1763
-
1764
- if response.status.code != status_code_pb2.SUCCESS:
1765
- raise Exception(f"Failed to upload model, error: {response.status}")
1766
- self.logger.info(
1767
- f"Success uploading model {self.id}, new version {response.model.model_version.id}")
1768
-
1769
- return Model.from_auth_helper(
1770
- auth=self.auth_helper,
1771
- model_id=self.id,
1772
- model_version=dict(id=response.model.model_version.id))
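And the URL-based counterpart, for a model zip hosted at a direct-download link (again, all identifiers are placeholders):

```python
new_version = model.create_version_by_url(
    url="https://example.com/my-model.zip",  # direct download link to the zip
    input_field_maps={"image": "input_tensor"},
    output_field_maps={"concepts": "scores"},
    inference_parameter_configs=inference_parameter_configs,  # from the earlier sketch
    description="Uploaded from hosted zip",
)
```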