clarifai 11.3.0rc2__py3-none-any.whl → 11.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (300) hide show
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/__main__.py +1 -1
  3. clarifai/cli/base.py +144 -136
  4. clarifai/cli/compute_cluster.py +45 -31
  5. clarifai/cli/deployment.py +93 -76
  6. clarifai/cli/model.py +578 -180
  7. clarifai/cli/nodepool.py +100 -82
  8. clarifai/client/__init__.py +12 -2
  9. clarifai/client/app.py +973 -911
  10. clarifai/client/auth/helper.py +345 -342
  11. clarifai/client/auth/register.py +7 -7
  12. clarifai/client/auth/stub.py +107 -106
  13. clarifai/client/base.py +185 -178
  14. clarifai/client/compute_cluster.py +214 -180
  15. clarifai/client/dataset.py +793 -698
  16. clarifai/client/deployment.py +55 -50
  17. clarifai/client/input.py +1223 -1088
  18. clarifai/client/lister.py +47 -45
  19. clarifai/client/model.py +1939 -1717
  20. clarifai/client/model_client.py +525 -502
  21. clarifai/client/module.py +82 -73
  22. clarifai/client/nodepool.py +358 -213
  23. clarifai/client/runner.py +58 -0
  24. clarifai/client/search.py +342 -309
  25. clarifai/client/user.py +419 -414
  26. clarifai/client/workflow.py +294 -274
  27. clarifai/constants/dataset.py +11 -17
  28. clarifai/constants/model.py +8 -2
  29. clarifai/datasets/export/inputs_annotations.py +233 -217
  30. clarifai/datasets/upload/base.py +63 -51
  31. clarifai/datasets/upload/features.py +43 -38
  32. clarifai/datasets/upload/image.py +237 -207
  33. clarifai/datasets/upload/loaders/coco_captions.py +34 -32
  34. clarifai/datasets/upload/loaders/coco_detection.py +72 -65
  35. clarifai/datasets/upload/loaders/imagenet_classification.py +57 -53
  36. clarifai/datasets/upload/loaders/xview_detection.py +274 -132
  37. clarifai/datasets/upload/multimodal.py +55 -46
  38. clarifai/datasets/upload/text.py +55 -47
  39. clarifai/datasets/upload/utils.py +250 -234
  40. clarifai/errors.py +51 -50
  41. clarifai/models/api.py +260 -238
  42. clarifai/modules/css.py +50 -50
  43. clarifai/modules/pages.py +33 -33
  44. clarifai/rag/rag.py +312 -288
  45. clarifai/rag/utils.py +91 -84
  46. clarifai/runners/models/model_builder.py +906 -802
  47. clarifai/runners/models/model_class.py +370 -331
  48. clarifai/runners/models/model_run_locally.py +459 -419
  49. clarifai/runners/models/model_runner.py +170 -162
  50. clarifai/runners/models/model_servicer.py +78 -70
  51. clarifai/runners/server.py +111 -101
  52. clarifai/runners/utils/code_script.py +225 -187
  53. clarifai/runners/utils/const.py +4 -1
  54. clarifai/runners/utils/data_types/__init__.py +12 -0
  55. clarifai/runners/utils/data_types/data_types.py +598 -0
  56. clarifai/runners/utils/data_utils.py +387 -440
  57. clarifai/runners/utils/loader.py +247 -227
  58. clarifai/runners/utils/method_signatures.py +411 -386
  59. clarifai/runners/utils/openai_convertor.py +108 -109
  60. clarifai/runners/utils/serializers.py +175 -179
  61. clarifai/runners/utils/url_fetcher.py +35 -35
  62. clarifai/schema/search.py +56 -63
  63. clarifai/urls/helper.py +125 -102
  64. clarifai/utils/cli.py +129 -123
  65. clarifai/utils/config.py +127 -87
  66. clarifai/utils/constants.py +49 -0
  67. clarifai/utils/evaluation/helpers.py +503 -466
  68. clarifai/utils/evaluation/main.py +431 -393
  69. clarifai/utils/evaluation/testset_annotation_parser.py +154 -144
  70. clarifai/utils/logging.py +324 -306
  71. clarifai/utils/misc.py +60 -56
  72. clarifai/utils/model_train.py +165 -146
  73. clarifai/utils/protobuf.py +126 -103
  74. clarifai/versions.py +3 -1
  75. clarifai/workflows/export.py +48 -50
  76. clarifai/workflows/utils.py +39 -36
  77. clarifai/workflows/validate.py +55 -43
  78. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/METADATA +16 -6
  79. clarifai-11.4.0.dist-info/RECORD +109 -0
  80. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/WHEEL +1 -1
  81. clarifai/__pycache__/__init__.cpython-310.pyc +0 -0
  82. clarifai/__pycache__/__init__.cpython-311.pyc +0 -0
  83. clarifai/__pycache__/__init__.cpython-39.pyc +0 -0
  84. clarifai/__pycache__/errors.cpython-310.pyc +0 -0
  85. clarifai/__pycache__/errors.cpython-311.pyc +0 -0
  86. clarifai/__pycache__/versions.cpython-310.pyc +0 -0
  87. clarifai/__pycache__/versions.cpython-311.pyc +0 -0
  88. clarifai/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  89. clarifai/cli/__pycache__/__init__.cpython-311.pyc +0 -0
  90. clarifai/cli/__pycache__/base.cpython-310.pyc +0 -0
  91. clarifai/cli/__pycache__/base.cpython-311.pyc +0 -0
  92. clarifai/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  93. clarifai/cli/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  94. clarifai/cli/__pycache__/compute_cluster.cpython-311.pyc +0 -0
  95. clarifai/cli/__pycache__/deployment.cpython-310.pyc +0 -0
  96. clarifai/cli/__pycache__/deployment.cpython-311.pyc +0 -0
  97. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  98. clarifai/cli/__pycache__/model.cpython-311.pyc +0 -0
  99. clarifai/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  100. clarifai/cli/__pycache__/nodepool.cpython-310.pyc +0 -0
  101. clarifai/cli/__pycache__/nodepool.cpython-311.pyc +0 -0
  102. clarifai/client/__pycache__/__init__.cpython-310.pyc +0 -0
  103. clarifai/client/__pycache__/__init__.cpython-311.pyc +0 -0
  104. clarifai/client/__pycache__/__init__.cpython-39.pyc +0 -0
  105. clarifai/client/__pycache__/app.cpython-310.pyc +0 -0
  106. clarifai/client/__pycache__/app.cpython-311.pyc +0 -0
  107. clarifai/client/__pycache__/app.cpython-39.pyc +0 -0
  108. clarifai/client/__pycache__/base.cpython-310.pyc +0 -0
  109. clarifai/client/__pycache__/base.cpython-311.pyc +0 -0
  110. clarifai/client/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  111. clarifai/client/__pycache__/compute_cluster.cpython-311.pyc +0 -0
  112. clarifai/client/__pycache__/dataset.cpython-310.pyc +0 -0
  113. clarifai/client/__pycache__/dataset.cpython-311.pyc +0 -0
  114. clarifai/client/__pycache__/deployment.cpython-310.pyc +0 -0
  115. clarifai/client/__pycache__/deployment.cpython-311.pyc +0 -0
  116. clarifai/client/__pycache__/input.cpython-310.pyc +0 -0
  117. clarifai/client/__pycache__/input.cpython-311.pyc +0 -0
  118. clarifai/client/__pycache__/lister.cpython-310.pyc +0 -0
  119. clarifai/client/__pycache__/lister.cpython-311.pyc +0 -0
  120. clarifai/client/__pycache__/model.cpython-310.pyc +0 -0
  121. clarifai/client/__pycache__/model.cpython-311.pyc +0 -0
  122. clarifai/client/__pycache__/module.cpython-310.pyc +0 -0
  123. clarifai/client/__pycache__/module.cpython-311.pyc +0 -0
  124. clarifai/client/__pycache__/nodepool.cpython-310.pyc +0 -0
  125. clarifai/client/__pycache__/nodepool.cpython-311.pyc +0 -0
  126. clarifai/client/__pycache__/search.cpython-310.pyc +0 -0
  127. clarifai/client/__pycache__/search.cpython-311.pyc +0 -0
  128. clarifai/client/__pycache__/user.cpython-310.pyc +0 -0
  129. clarifai/client/__pycache__/user.cpython-311.pyc +0 -0
  130. clarifai/client/__pycache__/workflow.cpython-310.pyc +0 -0
  131. clarifai/client/__pycache__/workflow.cpython-311.pyc +0 -0
  132. clarifai/client/auth/__pycache__/__init__.cpython-310.pyc +0 -0
  133. clarifai/client/auth/__pycache__/__init__.cpython-311.pyc +0 -0
  134. clarifai/client/auth/__pycache__/helper.cpython-310.pyc +0 -0
  135. clarifai/client/auth/__pycache__/helper.cpython-311.pyc +0 -0
  136. clarifai/client/auth/__pycache__/register.cpython-310.pyc +0 -0
  137. clarifai/client/auth/__pycache__/register.cpython-311.pyc +0 -0
  138. clarifai/client/auth/__pycache__/stub.cpython-310.pyc +0 -0
  139. clarifai/client/auth/__pycache__/stub.cpython-311.pyc +0 -0
  140. clarifai/client/cli/__init__.py +0 -0
  141. clarifai/client/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  142. clarifai/client/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  143. clarifai/client/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  144. clarifai/client/cli/base_cli.py +0 -88
  145. clarifai/client/cli/model_cli.py +0 -29
  146. clarifai/constants/__pycache__/base.cpython-310.pyc +0 -0
  147. clarifai/constants/__pycache__/base.cpython-311.pyc +0 -0
  148. clarifai/constants/__pycache__/dataset.cpython-310.pyc +0 -0
  149. clarifai/constants/__pycache__/dataset.cpython-311.pyc +0 -0
  150. clarifai/constants/__pycache__/input.cpython-310.pyc +0 -0
  151. clarifai/constants/__pycache__/input.cpython-311.pyc +0 -0
  152. clarifai/constants/__pycache__/model.cpython-310.pyc +0 -0
  153. clarifai/constants/__pycache__/model.cpython-311.pyc +0 -0
  154. clarifai/constants/__pycache__/rag.cpython-310.pyc +0 -0
  155. clarifai/constants/__pycache__/rag.cpython-311.pyc +0 -0
  156. clarifai/constants/__pycache__/search.cpython-310.pyc +0 -0
  157. clarifai/constants/__pycache__/search.cpython-311.pyc +0 -0
  158. clarifai/constants/__pycache__/workflow.cpython-310.pyc +0 -0
  159. clarifai/constants/__pycache__/workflow.cpython-311.pyc +0 -0
  160. clarifai/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  161. clarifai/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
  162. clarifai/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  163. clarifai/datasets/export/__pycache__/__init__.cpython-310.pyc +0 -0
  164. clarifai/datasets/export/__pycache__/__init__.cpython-311.pyc +0 -0
  165. clarifai/datasets/export/__pycache__/__init__.cpython-39.pyc +0 -0
  166. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-310.pyc +0 -0
  167. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-311.pyc +0 -0
  168. clarifai/datasets/upload/__pycache__/__init__.cpython-310.pyc +0 -0
  169. clarifai/datasets/upload/__pycache__/__init__.cpython-311.pyc +0 -0
  170. clarifai/datasets/upload/__pycache__/__init__.cpython-39.pyc +0 -0
  171. clarifai/datasets/upload/__pycache__/base.cpython-310.pyc +0 -0
  172. clarifai/datasets/upload/__pycache__/base.cpython-311.pyc +0 -0
  173. clarifai/datasets/upload/__pycache__/features.cpython-310.pyc +0 -0
  174. clarifai/datasets/upload/__pycache__/features.cpython-311.pyc +0 -0
  175. clarifai/datasets/upload/__pycache__/image.cpython-310.pyc +0 -0
  176. clarifai/datasets/upload/__pycache__/image.cpython-311.pyc +0 -0
  177. clarifai/datasets/upload/__pycache__/multimodal.cpython-310.pyc +0 -0
  178. clarifai/datasets/upload/__pycache__/multimodal.cpython-311.pyc +0 -0
  179. clarifai/datasets/upload/__pycache__/text.cpython-310.pyc +0 -0
  180. clarifai/datasets/upload/__pycache__/text.cpython-311.pyc +0 -0
  181. clarifai/datasets/upload/__pycache__/utils.cpython-310.pyc +0 -0
  182. clarifai/datasets/upload/__pycache__/utils.cpython-311.pyc +0 -0
  183. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-311.pyc +0 -0
  184. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-39.pyc +0 -0
  185. clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-311.pyc +0 -0
  186. clarifai/datasets/upload/loaders/__pycache__/imagenet_classification.cpython-311.pyc +0 -0
  187. clarifai/models/__pycache__/__init__.cpython-39.pyc +0 -0
  188. clarifai/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  189. clarifai/rag/__pycache__/__init__.cpython-310.pyc +0 -0
  190. clarifai/rag/__pycache__/__init__.cpython-311.pyc +0 -0
  191. clarifai/rag/__pycache__/__init__.cpython-39.pyc +0 -0
  192. clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
  193. clarifai/rag/__pycache__/rag.cpython-311.pyc +0 -0
  194. clarifai/rag/__pycache__/rag.cpython-39.pyc +0 -0
  195. clarifai/rag/__pycache__/utils.cpython-310.pyc +0 -0
  196. clarifai/rag/__pycache__/utils.cpython-311.pyc +0 -0
  197. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  198. clarifai/runners/__pycache__/__init__.cpython-311.pyc +0 -0
  199. clarifai/runners/__pycache__/__init__.cpython-39.pyc +0 -0
  200. clarifai/runners/dockerfile_template/Dockerfile.cpu.template +0 -31
  201. clarifai/runners/dockerfile_template/Dockerfile.cuda.template +0 -42
  202. clarifai/runners/dockerfile_template/Dockerfile.nim +0 -71
  203. clarifai/runners/models/__pycache__/__init__.cpython-310.pyc +0 -0
  204. clarifai/runners/models/__pycache__/__init__.cpython-311.pyc +0 -0
  205. clarifai/runners/models/__pycache__/__init__.cpython-39.pyc +0 -0
  206. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  207. clarifai/runners/models/__pycache__/base_typed_model.cpython-311.pyc +0 -0
  208. clarifai/runners/models/__pycache__/base_typed_model.cpython-39.pyc +0 -0
  209. clarifai/runners/models/__pycache__/model_builder.cpython-311.pyc +0 -0
  210. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  211. clarifai/runners/models/__pycache__/model_class.cpython-311.pyc +0 -0
  212. clarifai/runners/models/__pycache__/model_run_locally.cpython-310-pytest-7.1.2.pyc +0 -0
  213. clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
  214. clarifai/runners/models/__pycache__/model_run_locally.cpython-311.pyc +0 -0
  215. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  216. clarifai/runners/models/__pycache__/model_runner.cpython-311.pyc +0 -0
  217. clarifai/runners/models/__pycache__/model_upload.cpython-310.pyc +0 -0
  218. clarifai/runners/models/base_typed_model.py +0 -238
  219. clarifai/runners/models/model_class_refract.py +0 -80
  220. clarifai/runners/models/model_upload.py +0 -607
  221. clarifai/runners/models/temp.py +0 -25
  222. clarifai/runners/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  223. clarifai/runners/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  224. clarifai/runners/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  225. clarifai/runners/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  226. clarifai/runners/utils/__pycache__/buffered_stream.cpython-310.pyc +0 -0
  227. clarifai/runners/utils/__pycache__/buffered_stream.cpython-38.pyc +0 -0
  228. clarifai/runners/utils/__pycache__/buffered_stream.cpython-39.pyc +0 -0
  229. clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
  230. clarifai/runners/utils/__pycache__/const.cpython-311.pyc +0 -0
  231. clarifai/runners/utils/__pycache__/constants.cpython-310.pyc +0 -0
  232. clarifai/runners/utils/__pycache__/constants.cpython-38.pyc +0 -0
  233. clarifai/runners/utils/__pycache__/constants.cpython-39.pyc +0 -0
  234. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  235. clarifai/runners/utils/__pycache__/data_handler.cpython-311.pyc +0 -0
  236. clarifai/runners/utils/__pycache__/data_handler.cpython-38.pyc +0 -0
  237. clarifai/runners/utils/__pycache__/data_handler.cpython-39.pyc +0 -0
  238. clarifai/runners/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
  239. clarifai/runners/utils/__pycache__/data_utils.cpython-311.pyc +0 -0
  240. clarifai/runners/utils/__pycache__/data_utils.cpython-38.pyc +0 -0
  241. clarifai/runners/utils/__pycache__/data_utils.cpython-39.pyc +0 -0
  242. clarifai/runners/utils/__pycache__/grpc_server.cpython-310.pyc +0 -0
  243. clarifai/runners/utils/__pycache__/grpc_server.cpython-38.pyc +0 -0
  244. clarifai/runners/utils/__pycache__/grpc_server.cpython-39.pyc +0 -0
  245. clarifai/runners/utils/__pycache__/health.cpython-310.pyc +0 -0
  246. clarifai/runners/utils/__pycache__/health.cpython-38.pyc +0 -0
  247. clarifai/runners/utils/__pycache__/health.cpython-39.pyc +0 -0
  248. clarifai/runners/utils/__pycache__/loader.cpython-310.pyc +0 -0
  249. clarifai/runners/utils/__pycache__/loader.cpython-311.pyc +0 -0
  250. clarifai/runners/utils/__pycache__/logging.cpython-310.pyc +0 -0
  251. clarifai/runners/utils/__pycache__/logging.cpython-38.pyc +0 -0
  252. clarifai/runners/utils/__pycache__/logging.cpython-39.pyc +0 -0
  253. clarifai/runners/utils/__pycache__/stream_source.cpython-310.pyc +0 -0
  254. clarifai/runners/utils/__pycache__/stream_source.cpython-39.pyc +0 -0
  255. clarifai/runners/utils/__pycache__/url_fetcher.cpython-310.pyc +0 -0
  256. clarifai/runners/utils/__pycache__/url_fetcher.cpython-311.pyc +0 -0
  257. clarifai/runners/utils/__pycache__/url_fetcher.cpython-38.pyc +0 -0
  258. clarifai/runners/utils/__pycache__/url_fetcher.cpython-39.pyc +0 -0
  259. clarifai/runners/utils/data_handler.py +0 -231
  260. clarifai/runners/utils/data_handler_refract.py +0 -213
  261. clarifai/runners/utils/data_types.py +0 -469
  262. clarifai/runners/utils/logger.py +0 -0
  263. clarifai/runners/utils/openai_format.py +0 -87
  264. clarifai/schema/__pycache__/search.cpython-310.pyc +0 -0
  265. clarifai/schema/__pycache__/search.cpython-311.pyc +0 -0
  266. clarifai/urls/__pycache__/helper.cpython-310.pyc +0 -0
  267. clarifai/urls/__pycache__/helper.cpython-311.pyc +0 -0
  268. clarifai/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  269. clarifai/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  270. clarifai/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  271. clarifai/utils/__pycache__/cli.cpython-310.pyc +0 -0
  272. clarifai/utils/__pycache__/cli.cpython-311.pyc +0 -0
  273. clarifai/utils/__pycache__/config.cpython-311.pyc +0 -0
  274. clarifai/utils/__pycache__/constants.cpython-310.pyc +0 -0
  275. clarifai/utils/__pycache__/constants.cpython-311.pyc +0 -0
  276. clarifai/utils/__pycache__/logging.cpython-310.pyc +0 -0
  277. clarifai/utils/__pycache__/logging.cpython-311.pyc +0 -0
  278. clarifai/utils/__pycache__/misc.cpython-310.pyc +0 -0
  279. clarifai/utils/__pycache__/misc.cpython-311.pyc +0 -0
  280. clarifai/utils/__pycache__/model_train.cpython-310.pyc +0 -0
  281. clarifai/utils/__pycache__/model_train.cpython-311.pyc +0 -0
  282. clarifai/utils/__pycache__/protobuf.cpython-311.pyc +0 -0
  283. clarifai/utils/evaluation/__pycache__/__init__.cpython-311.pyc +0 -0
  284. clarifai/utils/evaluation/__pycache__/__init__.cpython-39.pyc +0 -0
  285. clarifai/utils/evaluation/__pycache__/helpers.cpython-311.pyc +0 -0
  286. clarifai/utils/evaluation/__pycache__/main.cpython-311.pyc +0 -0
  287. clarifai/utils/evaluation/__pycache__/main.cpython-39.pyc +0 -0
  288. clarifai/workflows/__pycache__/__init__.cpython-310.pyc +0 -0
  289. clarifai/workflows/__pycache__/__init__.cpython-311.pyc +0 -0
  290. clarifai/workflows/__pycache__/__init__.cpython-39.pyc +0 -0
  291. clarifai/workflows/__pycache__/export.cpython-310.pyc +0 -0
  292. clarifai/workflows/__pycache__/export.cpython-311.pyc +0 -0
  293. clarifai/workflows/__pycache__/utils.cpython-310.pyc +0 -0
  294. clarifai/workflows/__pycache__/utils.cpython-311.pyc +0 -0
  295. clarifai/workflows/__pycache__/validate.cpython-310.pyc +0 -0
  296. clarifai/workflows/__pycache__/validate.cpython-311.pyc +0 -0
  297. clarifai-11.3.0rc2.dist-info/RECORD +0 -322
  298. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/entry_points.txt +0 -0
  299. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info/licenses}/LICENSE +0 -0
  300. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@ import importlib
3
3
  import inspect
4
4
  import os
5
5
  import re
6
+ import shutil
6
7
  import sys
7
8
  import tarfile
8
9
  import time
@@ -19,9 +20,15 @@ from rich.markup import escape
19
20
  from clarifai.client.base import BaseClient
20
21
  from clarifai.runners.models.model_class import ModelClass
21
22
  from clarifai.runners.utils.const import (
22
- AVAILABLE_PYTHON_IMAGES, AVAILABLE_TORCH_IMAGES, CONCEPTS_REQUIRED_MODEL_TYPE,
23
- DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, DEFAULT_PYTHON_VERSION, DEFAULT_RUNTIME_DOWNLOAD_PATH,
24
- PYTHON_BASE_IMAGE, TORCH_BASE_IMAGE)
23
+ AVAILABLE_PYTHON_IMAGES,
24
+ AVAILABLE_TORCH_IMAGES,
25
+ CONCEPTS_REQUIRED_MODEL_TYPE,
26
+ DEFAULT_DOWNLOAD_CHECKPOINT_WHEN,
27
+ DEFAULT_PYTHON_VERSION,
28
+ DEFAULT_RUNTIME_DOWNLOAD_PATH,
29
+ PYTHON_BASE_IMAGE,
30
+ TORCH_BASE_IMAGE,
31
+ )
25
32
  from clarifai.runners.utils.loader import HuggingFaceLoader
26
33
  from clarifai.runners.utils.method_signatures import signatures_to_yaml
27
34
  from clarifai.urls.helper import ClarifaiUrlHelper
@@ -37,839 +44,936 @@ dependencies = [
37
44
 
38
45
 
39
46
  def _clear_line(n: int = 1) -> None:
40
- LINE_UP = '\033[1A' # Move cursor up one line
41
- LINE_CLEAR = '\x1b[2K' # Clear the entire line
42
- for _ in range(n):
43
- print(LINE_UP, end=LINE_CLEAR, flush=True)
47
+ LINE_UP = '\033[1A' # Move cursor up one line
48
+ LINE_CLEAR = '\x1b[2K' # Clear the entire line
49
+ for _ in range(n):
50
+ print(LINE_UP, end=LINE_CLEAR, flush=True)
44
51
 
45
52
 
46
53
  def is_related(object_class, main_class):
47
- # Check if the object_class is a subclass of main_class
48
- if issubclass(object_class, main_class):
49
- return True
54
+ # Check if the object_class is a subclass of main_class
55
+ if issubclass(object_class, main_class):
56
+ return True
50
57
 
51
- # Check if the object_class is a subclass of any of the parent classes of main_class
52
- parent_classes = object_class.__bases__
53
- for parent in parent_classes:
54
- if main_class in parent.__bases__:
55
- return True
56
- return False
58
+ # Check if the object_class is a subclass of any of the parent classes of main_class
59
+ parent_classes = object_class.__bases__
60
+ for parent in parent_classes:
61
+ if main_class in parent.__bases__:
62
+ return True
63
+ return False
57
64
 
58
65
 
59
66
  class ModelBuilder:
60
- DEFAULT_CHECKPOINT_SIZE = 50 * 1024**3 # 50 GiB
67
+ DEFAULT_CHECKPOINT_SIZE = 50 * 1024**3 # 50 GiB
68
+
69
+ def __init__(self, folder: str, validate_api_ids: bool = True, download_validation_only=False):
70
+ """
71
+ :param folder: The folder containing the model.py, config.yaml, requirements.txt and
72
+ checkpoints.
73
+ :param validate_api_ids: Whether to validate the user_id and app_id in the config file. TODO(zeiler):
74
+ deprecate in favor of download_validation_only.
75
+ :param download_validation_only: Whether to skip the API config validation. Set to True if
76
+ just downloading a checkpoint.
77
+ """
78
+ self._client = None
79
+ if not validate_api_ids: # for backwards compatibility
80
+ download_validation_only = True
81
+ self.download_validation_only = download_validation_only
82
+ self.folder = self._validate_folder(folder)
83
+ self.config = self._load_config(os.path.join(self.folder, 'config.yaml'))
84
+ self._validate_config()
85
+ self.model_proto = self._get_model_proto()
86
+ self.model_id = self.model_proto.id
87
+ self.model_version_id = None
88
+ self.inference_compute_info = self._get_inference_compute_info()
89
+ self.is_v3 = True # Do model build for v3
90
+
91
+ def create_model_instance(self, load_model=True, mocking=False):
92
+ """
93
+ Create an instance of the model class, as specified in the config file.
94
+ """
95
+ model_class = self.load_model_class(mocking=mocking)
96
+
97
+ # initialize the model
98
+ model = model_class()
99
+ if load_model:
100
+ model.load_model()
101
+ return model
102
+
103
+ def load_model_class(self, mocking=False):
104
+ """
105
+ Import the model class from the model.py file, dynamically handling missing dependencies
106
+ """
107
+ # look for default model.py file location
108
+ for loc in ["model.py", "1/model.py"]:
109
+ model_file = os.path.join(self.folder, loc)
110
+ if os.path.exists(model_file):
111
+ break
112
+ if not os.path.exists(model_file):
113
+ raise Exception("Model file not found.")
114
+
115
+ module_name = os.path.basename(model_file).replace(".py", "")
116
+
117
+ spec = importlib.util.spec_from_file_location(module_name, model_file)
118
+ module = importlib.util.module_from_spec(spec)
119
+ sys.modules[module_name] = module
120
+
121
+ original_import = builtins.__import__
122
+
123
+ def custom_import(name, globals=None, locals=None, fromlist=(), level=0):
124
+ # Allow standard libraries and clarifai
125
+ if self._is_standard_or_clarifai(name):
126
+ return original_import(name, globals, locals, fromlist, level)
127
+
128
+ # Mock all third-party imports to avoid ImportErrors or other issues
129
+ return MagicMock()
130
+
131
+ if mocking:
132
+ # Replace the built-in __import__ function with our custom one
133
+ builtins.__import__ = custom_import
61
134
 
62
- def __init__(self, folder: str, validate_api_ids: bool = True, download_validation_only=False):
63
- """
64
- :param folder: The folder containing the model.py, config.yaml, requirements.txt and
65
- checkpoints.
66
- :param validate_api_ids: Whether to validate the user_id and app_id in the config file. TODO(zeiler):
67
- deprecate in favor of download_validation_only.
68
- :param download_validation_only: Whether to skip the API config validation. Set to True if
69
- just downloading a checkpoint.
70
- """
71
- self._client = None
72
- if not validate_api_ids: # for backwards compatibility
73
- download_validation_only = True
74
- self.download_validation_only = download_validation_only
75
- self.folder = self._validate_folder(folder)
76
- self.config = self._load_config(os.path.join(self.folder, 'config.yaml'))
77
- self._validate_config()
78
- self.model_proto = self._get_model_proto()
79
- self.model_id = self.model_proto.id
80
- self.model_version_id = None
81
- self.inference_compute_info = self._get_inference_compute_info()
82
- self.is_v3 = True # Do model build for v3
83
-
84
- def create_model_instance(self, load_model=True, mocking=False):
85
- """
86
- Create an instance of the model class, as specified in the config file.
87
- """
88
- model_class = self.load_model_class(mocking=mocking)
135
+ try:
136
+ spec.loader.exec_module(module)
137
+ except Exception as e:
138
+ logger.error(f"Error loading model.py: {e}")
139
+ raise
140
+ finally:
141
+ # Restore the original __import__ function
142
+ builtins.__import__ = original_import
143
+
144
+ # Find all classes in the model.py file that are subclasses of ModelClass
145
+ classes = [
146
+ cls
147
+ for _, cls in inspect.getmembers(module, inspect.isclass)
148
+ if is_related(cls, ModelClass) and cls.__module__ == module.__name__
149
+ ]
150
+ # Ensure there is exactly one subclass of BaseRunner in the model.py file
151
+ if len(classes) != 1:
152
+ # check for old inheritence structure, ModelRunner used to be a ModelClass
153
+ runner_classes = [
154
+ cls
155
+ for _, cls in inspect.getmembers(module, inspect.isclass)
156
+ if cls.__module__ == module.__name__
157
+ and any(c.__name__ == 'ModelRunner' for c in cls.__bases__)
158
+ ]
159
+ if runner_classes and len(runner_classes) == 1:
160
+ raise Exception(
161
+ f'Could not determine model class.'
162
+ f' Models should now inherit from {ModelClass.__module__}.ModelClass, not ModelRunner.'
163
+ f' Please update your model "{runner_classes[0].__name__}" to inherit from ModelClass.'
164
+ )
165
+ raise Exception(
166
+ "Could not determine model class. There should be exactly one model inheriting from ModelClass defined in the model.py"
167
+ )
168
+ model_class = classes[0]
169
+ return model_class
170
+
171
+ def _is_standard_or_clarifai(self, name):
172
+ """Check if import is from standard library or clarifai"""
173
+ if name.startswith("clarifai"):
174
+ return True
175
+
176
+ # Handle Python <3.10 compatibility
177
+ stdlib_names = getattr(sys, "stdlib_module_names", sys.builtin_module_names)
178
+ if name in stdlib_names:
179
+ return True
180
+
181
+ # Handle submodules (e.g., os.path)
182
+ parts = name.split(".")
183
+ for i in range(1, len(parts)):
184
+ if ".".join(parts[:i]) in stdlib_names:
185
+ return True
186
+ return False
187
+
188
+ def _validate_folder(self, folder):
189
+ if folder == ".":
190
+ folder = "" # will getcwd() next which ends with /
191
+ if not os.path.isabs(folder):
192
+ folder = os.path.join(os.getcwd(), folder)
193
+ logger.debug(f"Validating folder: {folder}")
194
+ if not os.path.exists(folder):
195
+ raise FileNotFoundError(
196
+ f"Folder {folder} not found, please provide a valid folder path"
197
+ )
198
+ files = os.listdir(folder)
199
+ assert "config.yaml" in files, "config.yaml not found in the folder"
200
+ # If just downloading we don't need requirements.txt or the python code, we do need the
201
+ # 1/ folder to put 1/checkpoints into.
202
+ assert "1" in files, "Subfolder '1' not found in the folder"
203
+ if not self.download_validation_only:
204
+ assert "requirements.txt" in files, "requirements.txt not found in the folder"
205
+ subfolder_files = os.listdir(os.path.join(folder, '1'))
206
+ assert 'model.py' in subfolder_files, "model.py not found in the folder"
207
+ return folder
208
+
209
+ @staticmethod
210
+ def _load_config(config_file: str):
211
+ with open(config_file, 'r') as file:
212
+ config = yaml.safe_load(file)
213
+ return config
214
+
215
+ @staticmethod
216
+ def _backup_config(config_file: str):
217
+ if not os.path.exists(config_file):
218
+ return
219
+ backup_file = config_file + ".bak"
220
+ if os.path.exists(backup_file):
221
+ raise FileExistsError(
222
+ f"Backup file {backup_file} already exists. Please remove it before proceeding."
223
+ )
224
+ shutil.copy(config_file, backup_file)
225
+
226
+ @staticmethod
227
+ def _save_config(config_file: str, config: dict):
228
+ with open(config_file, 'w') as f:
229
+ yaml.safe_dump(config, f)
230
+
231
+ def _validate_config_checkpoints(self):
232
+ """
233
+ Validates the checkpoints section in the config file.
234
+ return loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns
235
+ :return: loader_type the type of loader or None if no checkpoints.
236
+ :return: repo_id location of checkpoint.
237
+ :return: hf_token token to access checkpoint.
238
+ :return: when one of ['upload', 'build', 'runtime'] to download checkpoint
239
+ :return: allowed_file_patterns patterns to allow in downloaded checkpoint
240
+ :return: ignore_file_patterns patterns to ignore in downloaded checkpoint
241
+ """
242
+ if "checkpoints" not in self.config:
243
+ return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
244
+ assert "type" in self.config.get("checkpoints"), (
245
+ "No loader type specified in the config file"
246
+ )
247
+ loader_type = self.config.get("checkpoints").get("type")
248
+ if not loader_type:
249
+ logger.info("No loader type specified in the config file for checkpoints")
250
+ return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
251
+ checkpoints = self.config.get("checkpoints")
252
+ if 'when' not in checkpoints:
253
+ logger.warn(
254
+ f"No 'when' specified in the config file for checkpoints, defaulting to download at {DEFAULT_DOWNLOAD_CHECKPOINT_WHEN}"
255
+ )
256
+ when = checkpoints.get("when", DEFAULT_DOWNLOAD_CHECKPOINT_WHEN)
257
+ assert when in [
258
+ "upload",
259
+ "build",
260
+ "runtime",
261
+ ], (
262
+ "Invalid value for when in the checkpoint loader when, needs to be one of ['upload', 'build', 'runtime']"
263
+ )
264
+ assert loader_type == "huggingface", "Only huggingface loader supported for now"
265
+ if loader_type == "huggingface":
266
+ assert "repo_id" in self.config.get("checkpoints"), (
267
+ "No repo_id specified in the config file"
268
+ )
269
+ repo_id = self.config.get("checkpoints").get("repo_id")
89
270
 
90
- # initialize the model
91
- model = model_class()
92
- if load_model:
93
- model.load_model()
94
- return model
271
+ # get from config.yaml otherwise fall back to HF_TOKEN env var.
272
+ hf_token = self.config.get("checkpoints").get(
273
+ "hf_token", os.environ.get("HF_TOKEN", None)
274
+ )
95
275
 
96
- def load_model_class(self, mocking=False):
97
- """
98
- Import the model class from the model.py file, dynamically handling missing dependencies
99
- """
100
- # look for default model.py file location
101
- for loc in ["model.py", "1/model.py"]:
102
- model_file = os.path.join(self.folder, loc)
103
- if os.path.exists(model_file):
104
- break
105
- if not os.path.exists(model_file):
106
- raise Exception("Model file not found.")
107
-
108
- module_name = os.path.basename(model_file).replace(".py", "")
109
-
110
- spec = importlib.util.spec_from_file_location(module_name, model_file)
111
- module = importlib.util.module_from_spec(spec)
112
- sys.modules[module_name] = module
113
-
114
- original_import = builtins.__import__
115
-
116
- def custom_import(name, globals=None, locals=None, fromlist=(), level=0):
117
-
118
- # Allow standard libraries and clarifai
119
- if self._is_standard_or_clarifai(name):
120
- return original_import(name, globals, locals, fromlist, level)
121
-
122
- # Mock all third-party imports to avoid ImportErrors or other issues
123
- return MagicMock()
124
-
125
- if mocking:
126
- # Replace the built-in __import__ function with our custom one
127
- builtins.__import__ = custom_import
128
-
129
- try:
130
- spec.loader.exec_module(module)
131
- except Exception as e:
132
- logger.error(f"Error loading model.py: {e}")
133
- raise
134
- finally:
135
- # Restore the original __import__ function
136
- builtins.__import__ = original_import
137
-
138
- # Find all classes in the model.py file that are subclasses of ModelClass
139
- classes = [
140
- cls for _, cls in inspect.getmembers(module, inspect.isclass)
141
- if is_related(cls, ModelClass) and cls.__module__ == module.__name__
142
- ]
143
- # Ensure there is exactly one subclass of BaseRunner in the model.py file
144
- if len(classes) != 1:
145
- # check for old inheritence structure, ModelRunner used to be a ModelClass
146
- runner_classes = [
147
- cls for _, cls in inspect.getmembers(module, inspect.isclass)
148
- if cls.__module__ == module.__name__ and any(c.__name__ == 'ModelRunner'
149
- for c in cls.__bases__)
150
- ]
151
- if runner_classes and len(runner_classes) == 1:
152
- raise Exception(
153
- f'Could not determine model class.'
154
- f' Models should now inherit from {ModelClass.__module__}.ModelClass, not ModelRunner.'
155
- f' Please update your model "{runner_classes[0].__name__}" to inherit from ModelClass.'
276
+ allowed_file_patterns = self.config.get("checkpoints").get(
277
+ 'allowed_file_patterns', None
278
+ )
279
+ if isinstance(allowed_file_patterns, str):
280
+ allowed_file_patterns = [allowed_file_patterns]
281
+ ignore_file_patterns = self.config.get("checkpoints").get('ignore_file_patterns', None)
282
+ if isinstance(ignore_file_patterns, str):
283
+ ignore_file_patterns = [ignore_file_patterns]
284
+ return (
285
+ loader_type,
286
+ repo_id,
287
+ hf_token,
288
+ when,
289
+ allowed_file_patterns,
290
+ ignore_file_patterns,
291
+ )
292
+
293
+ def _check_app_exists(self):
294
+ resp = self.client.STUB.GetApp(
295
+ service_pb2.GetAppRequest(user_app_id=self.client.user_app_id)
156
296
  )
157
- raise Exception(
158
- "Could not determine model class. There should be exactly one model inheriting from ModelClass defined in the model.py"
159
- )
160
- model_class = classes[0]
161
- return model_class
162
-
163
- def _is_standard_or_clarifai(self, name):
164
- """Check if import is from standard library or clarifai"""
165
- if name.startswith("clarifai"):
166
- return True
167
-
168
- # Handle Python <3.10 compatibility
169
- stdlib_names = getattr(sys, "stdlib_module_names", sys.builtin_module_names)
170
- if name in stdlib_names:
171
- return True
172
-
173
- # Handle submodules (e.g., os.path)
174
- parts = name.split(".")
175
- for i in range(1, len(parts)):
176
- if ".".join(parts[:i]) in stdlib_names:
177
- return True
178
- return False
297
+ if resp.status.code == status_code_pb2.SUCCESS:
298
+ return True
299
+ if resp.status.code == status_code_pb2.CONN_KEY_INVALID:
300
+ logger.error(
301
+ f"Invalid PAT provided for user {self.client.user_app_id.user_id}. Please check your PAT and try again."
302
+ )
303
+ return False
304
+ logger.error(
305
+ f"Error checking API {self._base_api} for user app {self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}. Error code: {resp.status.code}"
306
+ )
307
+ logger.error(
308
+ f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}. Please create the app first and try again."
309
+ )
310
+ return False
179
311
 
180
- def _validate_folder(self, folder):
181
- if folder == ".":
182
- folder = "" # will getcwd() next which ends with /
183
- if not os.path.isabs(folder):
184
- folder = os.path.join(os.getcwd(), folder)
185
- logger.info(f"Validating folder: {folder}")
186
- if not os.path.exists(folder):
187
- raise FileNotFoundError(f"Folder {folder} not found, please provide a valid folder path")
188
- files = os.listdir(folder)
189
- assert "config.yaml" in files, "config.yaml not found in the folder"
190
- # If just downloading we don't need requirements.txt or the python code, we do need the
191
- # 1/ folder to put 1/checkpoints into.
192
- assert "1" in files, "Subfolder '1' not found in the folder"
193
- if not self.download_validation_only:
194
- assert "requirements.txt" in files, "requirements.txt not found in the folder"
195
- subfolder_files = os.listdir(os.path.join(folder, '1'))
196
- assert 'model.py' in subfolder_files, "model.py not found in the folder"
197
- return folder
198
-
199
- @staticmethod
200
- def _load_config(config_file: str):
201
- with open(config_file, 'r') as file:
202
- config = yaml.safe_load(file)
203
- return config
204
-
205
- def _validate_config_checkpoints(self):
206
- """
207
- Validates the checkpoints section in the config file.
208
- return loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns
209
- :return: loader_type the type of loader or None if no checkpoints.
210
- :return: repo_id location of checkpoint.
211
- :return: hf_token token to access checkpoint.
212
- :return: when one of ['upload', 'build', 'runtime'] to download checkpoint
213
- :return: allowed_file_patterns patterns to allow in downloaded checkpoint
214
- :return: ignore_file_patterns patterns to ignore in downloaded checkpoint
215
- """
216
- if "checkpoints" not in self.config:
217
- return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
218
- assert "type" in self.config.get("checkpoints"), "No loader type specified in the config file"
219
- loader_type = self.config.get("checkpoints").get("type")
220
- if not loader_type:
221
- logger.info("No loader type specified in the config file for checkpoints")
222
- return None, None, None, DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, None, None
223
- checkpoints = self.config.get("checkpoints")
224
- if 'when' not in checkpoints:
225
- logger.warn(
226
- f"No 'when' specified in the config file for checkpoints, defaulting to download at {DEFAULT_DOWNLOAD_CHECKPOINT_WHEN}"
227
- )
228
- when = checkpoints.get("when", DEFAULT_DOWNLOAD_CHECKPOINT_WHEN)
229
- assert when in [
230
- "upload",
231
- "build",
232
- "runtime",
233
- ], "Invalid value for when in the checkpoint loader when, needs to be one of ['upload', 'build', 'runtime']"
234
- assert loader_type == "huggingface", "Only huggingface loader supported for now"
235
- if loader_type == "huggingface":
236
- assert "repo_id" in self.config.get("checkpoints"), "No repo_id specified in the config file"
237
- repo_id = self.config.get("checkpoints").get("repo_id")
238
-
239
- # get from config.yaml otherwise fall back to HF_TOKEN env var.
240
- hf_token = self.config.get("checkpoints").get("hf_token", os.environ.get("HF_TOKEN", None))
241
-
242
- allowed_file_patterns = self.config.get("checkpoints").get('allowed_file_patterns', None)
243
- if isinstance(allowed_file_patterns, str):
244
- allowed_file_patterns = [allowed_file_patterns]
245
- ignore_file_patterns = self.config.get("checkpoints").get('ignore_file_patterns', None)
246
- if isinstance(ignore_file_patterns, str):
247
- ignore_file_patterns = [ignore_file_patterns]
248
- return loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns
249
-
250
- def _check_app_exists(self):
251
- resp = self.client.STUB.GetApp(service_pb2.GetAppRequest(user_app_id=self.client.user_app_id))
252
- if resp.status.code == status_code_pb2.SUCCESS:
253
- return True
254
- if resp.status.code == status_code_pb2.CONN_KEY_INVALID:
255
- logger.error(
256
- f"Invalid PAT provided for user {self.client.user_app_id.user_id}. Please check your PAT and try again."
257
- )
258
- return False
259
- logger.error(
260
- f"Error checking API {self._base_api} for user app {self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}. Error code: {resp.status.code}"
261
- )
262
- logger.error(
263
- f"App {self.client.user_app_id.app_id} not found for user {self.client.user_app_id.user_id}. Please create the app first and try again."
264
- )
265
- return False
312
+ def _validate_config_model(self):
313
+ assert "model" in self.config, "model section not found in the config file"
314
+ model = self.config.get('model')
315
+ assert "user_id" in model, "user_id not found in the config file"
316
+ assert "app_id" in model, "app_id not found in the config file"
317
+ assert "model_type_id" in model, "model_type_id not found in the config file"
318
+ assert "id" in model, "model_id not found in the config file"
319
+ if '.' in model.get('id'):
320
+ logger.error(
321
+ "Model ID cannot contain '.', please remove it from the model_id in the config file"
322
+ )
323
+ sys.exit(1)
324
+
325
+ assert model.get('user_id') != "", "user_id cannot be empty in the config file"
326
+ assert model.get('app_id') != "", "app_id cannot be empty in the config file"
327
+ assert model.get('model_type_id') != "", "model_type_id cannot be empty in the config file"
328
+ assert model.get('id') != "", "model_id cannot be empty in the config file"
329
+
330
+ if not self._check_app_exists():
331
+ sys.exit(1)
332
+
333
+ @staticmethod
334
+ def _set_local_dev_model(config, user_id, app_id, model_id, model_type_id):
335
+ """
336
+ Sets the model configuration for local development.
337
+ This is used when running the model locally without uploading it to Clarifai.
338
+ """
339
+ if 'model' not in config:
340
+ config['model'] = {}
341
+ config["model"]["user_id"] = user_id
342
+ config["model"]["app_id"] = app_id
343
+ config["model"]["id"] = model_id
344
+ config["model"]["model_type_id"] = model_type_id
345
+ return config
346
+
347
+ def _validate_config(self):
348
+ if not self.download_validation_only:
349
+ self._validate_config_model()
350
+
351
+ assert "inference_compute_info" in self.config, (
352
+ "inference_compute_info not found in the config file"
353
+ )
266
354
 
267
- def _validate_config_model(self):
268
- assert "model" in self.config, "model section not found in the config file"
269
- model = self.config.get('model')
270
- assert "user_id" in model, "user_id not found in the config file"
271
- assert "app_id" in model, "app_id not found in the config file"
272
- assert "model_type_id" in model, "model_type_id not found in the config file"
273
- assert "id" in model, "model_id not found in the config file"
274
- if '.' in model.get('id'):
275
- logger.error(
276
- "Model ID cannot contain '.', please remove it from the model_id in the config file")
277
- sys.exit(1)
278
-
279
- assert model.get('user_id') != "", "user_id cannot be empty in the config file"
280
- assert model.get('app_id') != "", "app_id cannot be empty in the config file"
281
- assert model.get('model_type_id') != "", "model_type_id cannot be empty in the config file"
282
- assert model.get('id') != "", "model_id cannot be empty in the config file"
283
-
284
- if not self._check_app_exists():
285
- sys.exit(1)
286
-
287
- def _validate_config(self):
288
- if not self.download_validation_only:
289
- self._validate_config_model()
290
-
291
- assert "inference_compute_info" in self.config, "inference_compute_info not found in the config file"
292
-
293
- if self.config.get("concepts"):
294
- model_type_id = self.config.get('model').get('model_type_id')
295
- assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, f"Model type {model_type_id} not supported for concepts"
296
-
297
- if self.config.get("checkpoints"):
298
- loader_type, _, hf_token, _, _, _ = self._validate_config_checkpoints()
299
-
300
- if loader_type == "huggingface" and hf_token:
301
- is_valid_token = HuggingFaceLoader.validate_hftoken(hf_token)
302
- if not is_valid_token:
303
- logger.error(
304
- "Invalid Hugging Face token provided in the config file, this might cause issues with downloading the restricted model checkpoints."
305
- )
306
- logger.info("Continuing without Hugging Face token")
307
-
308
- num_threads = self.config.get("num_threads")
309
- if num_threads or num_threads == 0:
310
- assert isinstance(num_threads, int) and num_threads >= 1, ValueError(
311
- f"`num_threads` must be an integer greater than or equal to 1. Received type {type(num_threads)} with value {num_threads}."
312
- )
313
- else:
314
- num_threads = int(os.environ.get("CLARIFAI_NUM_THREADS", 16))
315
- self.config["num_threads"] = num_threads
355
+ if self.config.get("concepts"):
356
+ model_type_id = self.config.get('model').get('model_type_id')
357
+ assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, (
358
+ f"Model type {model_type_id} not supported for concepts"
359
+ )
360
+
361
+ if self.config.get("checkpoints"):
362
+ loader_type, _, hf_token, _, _, _ = self._validate_config_checkpoints()
363
+
364
+ if loader_type == "huggingface" and hf_token:
365
+ is_valid_token = HuggingFaceLoader.validate_hftoken(hf_token)
366
+ if not is_valid_token:
367
+ logger.error(
368
+ "Invalid Hugging Face token provided in the config file, this might cause issues with downloading the restricted model checkpoints."
369
+ )
370
+ logger.info("Continuing without Hugging Face token")
371
+
372
+ num_threads = self.config.get("num_threads")
373
+ if num_threads or num_threads == 0:
374
+ assert isinstance(num_threads, int) and num_threads >= 1, ValueError(
375
+ f"`num_threads` must be an integer greater than or equal to 1. Received type {type(num_threads)} with value {num_threads}."
376
+ )
377
+ else:
378
+ num_threads = int(os.environ.get("CLARIFAI_NUM_THREADS", 16))
379
+ self.config["num_threads"] = num_threads
380
+
381
+ @staticmethod
382
+ def _get_tar_file_content_size(tar_file_path):
383
+ """
384
+ Calculates the total size of the contents of a tar file.
385
+
386
+ Args:
387
+ tar_file_path (str): The path to the tar file.
388
+
389
+ Returns:
390
+ int: The total size of the contents in bytes.
391
+ """
392
+ total_size = 0
393
+ with tarfile.open(tar_file_path, 'r') as tar:
394
+ for member in tar:
395
+ if member.isfile():
396
+ total_size += member.size
397
+ return total_size
398
+
399
+ def method_signatures_yaml(self):
400
+ """
401
+ Returns the method signatures for the model class in YAML format.
402
+ """
403
+ model_class = self.load_model_class(mocking=True)
404
+ method_info = model_class._get_method_info()
405
+ signatures = {method.name: method.signature for method in method_info.values()}
406
+ return signatures_to_yaml(signatures)
407
+
408
+ def get_method_signatures(self):
409
+ """
410
+ Returns the method signatures for the model class.
411
+ """
412
+ model_class = self.load_model_class(mocking=True)
413
+ method_info = model_class._get_method_info()
414
+ signatures = [method.signature for method in method_info.values()]
415
+ return signatures
416
+
417
+ @property
418
+ def client(self):
419
+ if self._client is None:
420
+ assert "model" in self.config, "model info not found in the config file"
421
+ model = self.config.get('model')
422
+ assert "user_id" in model, "user_id not found in the config file"
423
+ assert "app_id" in model, "app_id not found in the config file"
424
+ # The owner of the model and the app.
425
+ user_id = model.get('user_id')
426
+ app_id = model.get('app_id')
427
+
428
+ self._base_api = os.environ.get('CLARIFAI_API_BASE', 'https://api.clarifai.com')
429
+ self._client = BaseClient(user_id=user_id, app_id=app_id, base=self._base_api)
430
+
431
+ return self._client
432
+
433
+ @property
434
+ def model_url(self):
435
+ url_helper = ClarifaiUrlHelper(self._client.auth_helper)
436
+ if self.model_version_id is not None:
437
+ return url_helper.clarifai_url(
438
+ self.client.user_app_id.user_id,
439
+ self.client.user_app_id.app_id,
440
+ "models",
441
+ self.model_id,
442
+ )
443
+ else:
444
+ return url_helper.clarifai_url(
445
+ self.client.user_app_id.user_id,
446
+ self.client.user_app_id.app_id,
447
+ "models",
448
+ self.model_id,
449
+ self.model_version_id,
450
+ )
316
451
 
317
- @staticmethod
318
- def _get_tar_file_content_size(tar_file_path):
319
- """
320
- Calculates the total size of the contents of a tar file.
452
+ def _get_model_proto(self):
453
+ assert "model" in self.config, "model info not found in the config file"
454
+ model = self.config.get('model')
321
455
 
322
- Args:
323
- tar_file_path (str): The path to the tar file.
456
+ assert "model_type_id" in model, "model_type_id not found in the config file"
457
+ assert "id" in model, "model_id not found in the config file"
458
+ if not self.download_validation_only:
459
+ assert "user_id" in model, "user_id not found in the config file"
460
+ assert "app_id" in model, "app_id not found in the config file"
324
461
 
325
- Returns:
326
- int: The total size of the contents in bytes.
327
- """
328
- total_size = 0
329
- with tarfile.open(tar_file_path, 'r') as tar:
330
- for member in tar:
331
- if member.isfile():
332
- total_size += member.size
333
- return total_size
334
-
335
- def method_signatures_yaml(self):
336
- """
337
- Returns the method signatures for the model class in YAML format.
338
- """
339
- model_class = self.load_model_class(mocking=True)
340
- method_info = model_class._get_method_info()
341
- signatures = {method.name: method.signature for method in method_info.values()}
342
- return signatures_to_yaml(signatures)
462
+ model_proto = json_format.ParseDict(model, resources_pb2.Model())
343
463
 
344
- def get_method_signatures(self):
345
- """
346
- Returns the method signatures for the model class.
347
- """
348
- model_class = self.load_model_class(mocking=True)
349
- method_info = model_class._get_method_info()
350
- signatures = [method.signature for method in method_info.values()]
351
- return signatures
352
-
353
- @property
354
- def client(self):
355
- if self._client is None:
356
- assert "model" in self.config, "model info not found in the config file"
357
- model = self.config.get('model')
358
- assert "user_id" in model, "user_id not found in the config file"
359
- assert "app_id" in model, "app_id not found in the config file"
360
- # The owner of the model and the app.
361
- user_id = model.get('user_id')
362
- app_id = model.get('app_id')
363
-
364
- self._base_api = os.environ.get('CLARIFAI_API_BASE', 'https://api.clarifai.com')
365
- self._client = BaseClient(user_id=user_id, app_id=app_id, base=self._base_api)
366
-
367
- return self._client
368
-
369
- @property
370
- def model_url(self):
371
- url_helper = ClarifaiUrlHelper(self._client.auth_helper)
372
- if self.model_version_id is not None:
373
- return url_helper.clarifai_url(self.client.user_app_id.user_id,
374
- self.client.user_app_id.app_id, "models", self.model_id)
375
- else:
376
- return url_helper.clarifai_url(self.client.user_app_id.user_id,
377
- self.client.user_app_id.app_id, "models", self.model_id,
378
- self.model_version_id)
379
-
380
- def _get_model_proto(self):
381
- assert "model" in self.config, "model info not found in the config file"
382
- model = self.config.get('model')
383
-
384
- assert "model_type_id" in model, "model_type_id not found in the config file"
385
- assert "id" in model, "model_id not found in the config file"
386
- if not self.download_validation_only:
387
- assert "user_id" in model, "user_id not found in the config file"
388
- assert "app_id" in model, "app_id not found in the config file"
389
-
390
- model_proto = json_format.ParseDict(model, resources_pb2.Model())
391
-
392
- return model_proto
393
-
394
- def _get_inference_compute_info(self):
395
- assert ("inference_compute_info" in self.config
396
- ), "inference_compute_info not found in the config file"
397
- inference_compute_info = self.config.get('inference_compute_info')
398
- return json_format.ParseDict(inference_compute_info, resources_pb2.ComputeInfo())
399
-
400
- def check_model_exists(self):
401
- resp = self.client.STUB.GetModel(
402
- service_pb2.GetModelRequest(
403
- user_app_id=self.client.user_app_id, model_id=self.model_proto.id))
404
- if resp.status.code == status_code_pb2.SUCCESS:
405
- return True
406
- return False
464
+ return model_proto
407
465
 
408
- def maybe_create_model(self):
409
- if self.check_model_exists():
410
- logger.info(
411
- f"Model '{self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}/models/{self.model_proto.id}' already exists, "
412
- f"will create a new version for it.")
413
- return
414
-
415
- request = service_pb2.PostModelsRequest(
416
- user_app_id=self.client.user_app_id,
417
- models=[self.model_proto],
418
- )
419
- return self.client.STUB.PostModels(request)
420
-
421
- def _match_req_line(self, line):
422
- line = line.strip()
423
- if not line or line.startswith('#'):
424
- return None, None
425
- # split on whitespace followed by #
426
- line = re.split(r'\s+#', line)[0]
427
- if "==" in line:
428
- pkg, version = line.split("==")
429
- elif ">=" in line:
430
- pkg, version = line.split(">=")
431
- elif ">" in line:
432
- pkg, version = line.split(">")
433
- elif "<=" in line:
434
- pkg, version = line.split("<=")
435
- elif "<" in line:
436
- pkg, version = line.split("<")
437
- else:
438
- pkg, version = line, None # No version specified
439
- for dep in dependencies:
440
- if dep == pkg:
441
- if dep == 'torch' and line.find(
442
- 'whl/cpu') > 0: # Ignore torch-cpu whl files, use base mage.
443
- return None, None
444
- return dep.strip(), version.strip() if version else None
445
- return None, None
446
-
447
- def _parse_requirements(self):
448
- dependencies_version = {}
449
- with open(os.path.join(self.folder, 'requirements.txt'), 'r') as file:
450
- for line in file:
451
- # Skip empty lines and comments
452
- dependency, version = self._match_req_line(line)
453
- if dependency is None:
454
- continue
455
- dependencies_version[dependency] = version if version else None
456
- return dependencies_version
457
-
458
- def create_dockerfile(self):
459
- dockerfile_template = os.path.join(
460
- os.path.dirname(os.path.dirname(__file__)),
461
- 'dockerfile_template',
462
- 'Dockerfile.template',
463
- )
464
-
465
- with open(dockerfile_template, 'r') as template_file:
466
- dockerfile_template = template_file.read()
467
-
468
- dockerfile_template = Template(dockerfile_template)
469
-
470
- # Get the Python version from the config file
471
- build_info = self.config.get('build_info', {})
472
- if 'python_version' in build_info:
473
- python_version = build_info['python_version']
474
- if python_version not in AVAILABLE_PYTHON_IMAGES:
475
- raise Exception(
476
- f"Python version {python_version} not supported, please use one of the following versions: {AVAILABLE_PYTHON_IMAGES} in your config.yaml"
466
+ def _get_inference_compute_info(self):
467
+ assert "inference_compute_info" in self.config, (
468
+ "inference_compute_info not found in the config file"
477
469
  )
470
+ inference_compute_info = self.config.get('inference_compute_info')
471
+ return json_format.ParseDict(inference_compute_info, resources_pb2.ComputeInfo())
478
472
 
479
- logger.info(
480
- f"Using Python version {python_version} from the config file to build the Dockerfile")
481
- else:
482
- logger.info(
483
- f"Python version not found in the config file, using default Python version: {DEFAULT_PYTHON_VERSION}"
484
- )
485
- python_version = DEFAULT_PYTHON_VERSION
486
-
487
- # This is always the final image used for runtime.
488
- final_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
489
- downloader_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
490
-
491
- # Parse the requirements.txt file to determine the base image
492
- dependencies = self._parse_requirements()
493
- if 'torch' in dependencies and dependencies['torch']:
494
- torch_version = dependencies['torch']
495
-
496
- # Sort in reverse so that newer cuda versions come first and are preferred.
497
- for image in sorted(AVAILABLE_TORCH_IMAGES, reverse=True):
498
- if torch_version in image and f'py{python_version}' in image:
499
- # like cu124, rocm6.3, etc.
500
- gpu_version = image.split('-')[-1]
501
- final_image = TORCH_BASE_IMAGE.format(
502
- torch_version=torch_version,
503
- python_version=python_version,
504
- gpu_version=gpu_version,
505
- )
506
- logger.info(f"Using Torch version {torch_version} base image to build the Docker image")
507
- break
508
-
509
- if 'clarifai' not in dependencies:
510
- raise Exception(
511
- f"clarifai not found in requirements.txt, please add clarifai to the requirements.txt file with a fixed version. Current version is clarifai=={CLIENT_VERSION}"
512
- )
513
- clarifai_version = dependencies['clarifai']
514
- if not clarifai_version:
515
- logger.warn(
516
- f"clarifai version not found in requirements.txt, using the latest version {CLIENT_VERSION}"
517
- )
518
- clarifai_version = CLIENT_VERSION
519
- lines = []
520
- with open(os.path.join(self.folder, 'requirements.txt'), 'r') as file:
521
- for line in file:
522
- # if the line without whitespace is "clarifai"
523
- dependency, version = self._match_req_line(line)
524
- if dependency and dependency == "clarifai":
525
- lines.append(line.replace("clarifai", f"clarifai=={CLIENT_VERSION}"))
526
- else:
527
- lines.append(line)
528
- with open(os.path.join(self.folder, 'requirements.txt'), 'w') as file:
529
- file.writelines(lines)
530
- logger.warn(f"Updated requirements.txt to have clarifai=={CLIENT_VERSION}")
531
-
532
- # Replace placeholders with actual values
533
- dockerfile_content = dockerfile_template.safe_substitute(
534
- name='main',
535
- FINAL_IMAGE=final_image, # for pip requirements
536
- DOWNLOADER_IMAGE=downloader_image, # for downloading checkpoints
537
- CLARIFAI_VERSION=clarifai_version, # for clarifai
538
- )
539
-
540
- # Write Dockerfile
541
- with open(os.path.join(self.folder, 'Dockerfile'), 'w') as dockerfile:
542
- dockerfile.write(dockerfile_content)
543
-
544
- @property
545
- def checkpoint_path(self):
546
- return self._checkpoint_path(self.folder)
547
-
548
- def _checkpoint_path(self, folder):
549
- return os.path.join(folder, self.checkpoint_suffix)
550
-
551
- @property
552
- def checkpoint_suffix(self):
553
- return os.path.join('1', 'checkpoints')
554
-
555
- @property
556
- def tar_file(self):
557
- return f"{self.folder}.tar.gz"
558
-
559
- def default_runtime_checkpoint_path(self):
560
- return DEFAULT_RUNTIME_DOWNLOAD_PATH
561
-
562
- def download_checkpoints(self,
563
- stage: str = DEFAULT_DOWNLOAD_CHECKPOINT_WHEN,
564
- checkpoint_path_override: str = None):
565
- """
566
- Downloads the checkpoints specified in the config file.
473
+ def check_model_exists(self):
474
+ resp = self.client.STUB.GetModel(
475
+ service_pb2.GetModelRequest(
476
+ user_app_id=self.client.user_app_id, model_id=self.model_proto.id
477
+ )
478
+ )
479
+ if resp.status.code == status_code_pb2.SUCCESS:
480
+ return True
481
+ return False
567
482
 
568
- :param stage: The stage of the build process. This is used to determine when to download the
569
- checkpoints. The stage can be one of ['build', 'upload', 'runtime']. If you want to force
570
- downloading now then set stage to match e when field of the checkpoints section of you config.yaml.
571
- :param checkpoint_path_override: The path to download the checkpoints to (with 1/checkpoints added as suffix). If not provided, the
572
- default path is used based on the folder ModelUploader was initialized with. The checkpoint_suffix will be appended to the path.
573
- If stage is 'runtime' and checkpoint_path_override is None, the default runtime path will be used.
483
+ def maybe_create_model(self):
484
+ if self.check_model_exists():
485
+ logger.info(
486
+ f"Model '{self.client.user_app_id.user_id}/{self.client.user_app_id.app_id}/models/{self.model_proto.id}' already exists, "
487
+ f"will create a new version for it."
488
+ )
489
+ return
574
490
 
575
- :return: The path to the downloaded checkpoints. Even if it doesn't download anything, it will return the default path.
576
- """
577
- path = self.checkpoint_path # default checkpoint path.
578
- if not self.config.get("checkpoints"):
579
- logger.info("No checkpoints specified in the config file")
580
- return path
581
- clarifai_model_type_id = self.config.get('model').get('model_type_id')
582
-
583
- loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns = self._validate_config_checkpoints(
584
- )
585
- if stage not in ["build", "upload", "runtime"]:
586
- raise Exception("Invalid stage provided, must be one of ['build', 'upload', 'runtime']")
587
- if when != stage:
588
- logger.info(
589
- f"Skipping downloading checkpoints for stage {stage} since config.yaml says to download them at stage {when}"
590
- )
591
- return path
592
-
593
- success = False
594
- if loader_type == "huggingface":
595
- loader = HuggingFaceLoader(
596
- repo_id=repo_id, token=hf_token, model_type_id=clarifai_model_type_id)
597
- # for runtime default to /tmp path
598
- if stage == "runtime" and checkpoint_path_override is None:
599
- checkpoint_path_override = self.default_runtime_checkpoint_path()
600
- path = checkpoint_path_override if checkpoint_path_override else self.checkpoint_path
601
- success = loader.download_checkpoints(
602
- path,
603
- allowed_file_patterns=allowed_file_patterns,
604
- ignore_file_patterns=ignore_file_patterns)
605
-
606
- if loader_type:
607
- if not success:
608
- logger.error(f"Failed to download checkpoints for model {repo_id}")
609
- sys.exit(1)
610
- else:
611
- logger.info(f"Downloaded checkpoints for model {repo_id}")
612
- return path
613
-
614
- def _concepts_protos_from_concepts(self, concepts):
615
- concept_protos = []
616
- for concept in concepts:
617
- concept_protos.append(resources_pb2.Concept(
618
- id=str(concept[0]),
619
- name=concept[1],
620
- ))
621
- return concept_protos
622
-
623
- def hf_labels_to_config(self, labels, config_file):
624
- with open(config_file, 'r') as file:
625
- config = yaml.safe_load(file)
626
- model = config.get('model')
627
- model_type_id = model.get('model_type_id')
628
- assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, f"Model type {model_type_id} not supported for concepts"
629
- concept_protos = self._concepts_protos_from_concepts(labels)
630
-
631
- config['concepts'] = [{'id': concept.id, 'name': concept.name} for concept in concept_protos]
632
-
633
- with open(config_file, 'w') as file:
634
- yaml.dump(config, file, sort_keys=False)
635
- concepts = config.get('concepts')
636
- logger.info(f"Updated config.yaml with {len(concepts)} concepts.")
637
-
638
- def get_model_version_proto(self):
639
- signatures = self.get_method_signatures()
640
- model_version_proto = resources_pb2.ModelVersion(
641
- pretrained_model_config=resources_pb2.PretrainedModelConfig(),
642
- inference_compute_info=self.inference_compute_info,
643
- method_signatures=signatures,
644
- )
645
-
646
- model_type_id = self.config.get('model').get('model_type_id')
647
- if model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE:
648
-
649
- if 'concepts' in self.config:
650
- labels = self.config.get('concepts')
651
- logger.info(f"Found {len(labels)} concepts in the config file.")
652
- for concept in labels:
653
- concept_proto = json_format.ParseDict(concept, resources_pb2.Concept())
654
- model_version_proto.output_info.data.concepts.append(concept_proto)
655
- elif self.config.get("checkpoints") and HuggingFaceLoader.validate_concept(
656
- self.checkpoint_path):
657
- labels = HuggingFaceLoader.fetch_labels(self.checkpoint_path)
658
- logger.info(f"Found {len(labels)} concepts from the model checkpoints.")
659
- # sort the concepts by id and then update the config file
660
- labels = sorted(labels.items(), key=lambda x: int(x[0]))
661
-
662
- config_file = os.path.join(self.folder, 'config.yaml')
663
- try:
664
- self.hf_labels_to_config(labels, config_file)
665
- except Exception as e:
666
- logger.error(f"Failed to update the config.yaml file with the concepts: {e}")
491
+ request = service_pb2.PostModelsRequest(
492
+ user_app_id=self.client.user_app_id,
493
+ models=[self.model_proto],
494
+ )
495
+ return self.client.STUB.PostModels(request)
496
+
497
+ def _match_req_line(self, line):
498
+ line = line.strip()
499
+ if not line or line.startswith('#'):
500
+ return None, None
501
+ # split on whitespace followed by #
502
+ line = re.split(r'\s+#', line)[0]
503
+ if "==" in line:
504
+ pkg, version = line.split("==")
505
+ elif ">=" in line:
506
+ pkg, version = line.split(">=")
507
+ elif ">" in line:
508
+ pkg, version = line.split(">")
509
+ elif "<=" in line:
510
+ pkg, version = line.split("<=")
511
+ elif "<" in line:
512
+ pkg, version = line.split("<")
513
+ else:
514
+ pkg, version = line, None # No version specified
515
+ for dep in dependencies:
516
+ if dep == pkg:
517
+ if (
518
+ dep == 'torch' and line.find('whl/cpu') > 0
519
+ ): # Ignore torch-cpu whl files, use base mage.
520
+ return None, None
521
+ return dep.strip(), version.strip() if version else None
522
+ return None, None
523
+
524
+ def _parse_requirements(self):
525
+ dependencies_version = {}
526
+ with open(os.path.join(self.folder, 'requirements.txt'), 'r') as file:
527
+ for line in file:
528
+ # Skip empty lines and comments
529
+ dependency, version = self._match_req_line(line)
530
+ if dependency is None:
531
+ continue
532
+ dependencies_version[dependency] = version if version else None
533
+ return dependencies_version
534
+
535
+ def create_dockerfile(self):
536
+ dockerfile_template = os.path.join(
537
+ os.path.dirname(os.path.dirname(__file__)),
538
+ 'dockerfile_template',
539
+ 'Dockerfile.template',
540
+ )
667
541
 
668
- model_version_proto.output_info.data.concepts.extend(
669
- self._concepts_protos_from_concepts(labels))
670
- return model_version_proto
542
+ with open(dockerfile_template, 'r') as template_file:
543
+ dockerfile_template = template_file.read()
671
544
 
672
- def upload_model_version(self):
673
- file_path = f"{self.folder}.tar.gz"
674
- logger.debug(f"Will tar it into file: {file_path}")
545
+ dockerfile_template = Template(dockerfile_template)
675
546
 
676
- model_type_id = self.config.get('model').get('model_type_id')
677
- loader_type, repo_id, hf_token, when, _, _ = self._validate_config_checkpoints()
547
+ # Get the Python version from the config file
548
+ build_info = self.config.get('build_info', {})
549
+ if 'python_version' in build_info:
550
+ python_version = build_info['python_version']
551
+ if python_version not in AVAILABLE_PYTHON_IMAGES:
552
+ raise Exception(
553
+ f"Python version {python_version} not supported, please use one of the following versions: {AVAILABLE_PYTHON_IMAGES} in your config.yaml"
554
+ )
678
555
 
679
- if (model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE) and 'concepts' not in self.config:
680
- logger.info(
681
- f"Model type {model_type_id} requires concepts to be specified in the config.yaml file.."
682
- )
683
- if self.config.get("checkpoints"):
684
- logger.info(
685
- "Checkpoints specified in the config.yaml file, will download the HF model's config.json file to infer the concepts."
556
+ logger.info(
557
+ f"Using Python version {python_version} from the config file to build the Dockerfile"
558
+ )
559
+ else:
560
+ logger.info(
561
+ f"Python version not found in the config file, using default Python version: {DEFAULT_PYTHON_VERSION}"
562
+ )
563
+ python_version = DEFAULT_PYTHON_VERSION
564
+
565
+ # This is always the final image used for runtime.
566
+ final_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
567
+ downloader_image = PYTHON_BASE_IMAGE.format(python_version=python_version)
568
+
569
+ # Parse the requirements.txt file to determine the base image
570
+ dependencies = self._parse_requirements()
571
+ if 'torch' in dependencies and dependencies['torch']:
572
+ torch_version = dependencies['torch']
573
+
574
+ # Sort in reverse so that newer cuda versions come first and are preferred.
575
+ for image in sorted(AVAILABLE_TORCH_IMAGES, reverse=True):
576
+ if torch_version in image and f'py{python_version}' in image:
577
+ # like cu124, rocm6.3, etc.
578
+ gpu_version = image.split('-')[-1]
579
+ final_image = TORCH_BASE_IMAGE.format(
580
+ torch_version=torch_version,
581
+ python_version=python_version,
582
+ gpu_version=gpu_version,
583
+ )
584
+ logger.info(
585
+ f"Using Torch version {torch_version} base image to build the Docker image"
586
+ )
587
+ break
588
+
589
+ if 'clarifai' not in dependencies:
590
+ raise Exception(
591
+ f"clarifai not found in requirements.txt, please add clarifai to the requirements.txt file with a fixed version. Current version is clarifai=={CLIENT_VERSION}"
592
+ )
593
+ clarifai_version = dependencies['clarifai']
594
+ if not clarifai_version:
595
+ logger.warn(
596
+ f"clarifai version not found in requirements.txt, using the latest version {CLIENT_VERSION}"
597
+ )
598
+ clarifai_version = CLIENT_VERSION
599
+ lines = []
600
+ with open(os.path.join(self.folder, 'requirements.txt'), 'r') as file:
601
+ for line in file:
602
+ # if the line without whitespace is "clarifai"
603
+ dependency, version = self._match_req_line(line)
604
+ if dependency and dependency == "clarifai":
605
+ lines.append(line.replace("clarifai", f"clarifai=={CLIENT_VERSION}"))
606
+ else:
607
+ lines.append(line)
608
+ with open(os.path.join(self.folder, 'requirements.txt'), 'w') as file:
609
+ file.writelines(lines)
610
+ logger.warn(f"Updated requirements.txt to have clarifai=={CLIENT_VERSION}")
611
+
612
+ # Replace placeholders with actual values
613
+ dockerfile_content = dockerfile_template.safe_substitute(
614
+ name='main',
615
+ FINAL_IMAGE=final_image, # for pip requirements
616
+ DOWNLOADER_IMAGE=downloader_image, # for downloading checkpoints
617
+ CLARIFAI_VERSION=clarifai_version, # for clarifai
618
+ )
619
+
620
+ # Write Dockerfile
621
+ with open(os.path.join(self.folder, 'Dockerfile'), 'w') as dockerfile:
622
+ dockerfile.write(dockerfile_content)
623
+
624
+ @property
625
+ def checkpoint_path(self):
626
+ return self._checkpoint_path(self.folder)
627
+
628
+ def _checkpoint_path(self, folder):
629
+ return os.path.join(folder, self.checkpoint_suffix)
630
+
631
+ @property
632
+ def checkpoint_suffix(self):
633
+ return os.path.join('1', 'checkpoints')
634
+
635
+ @property
636
+ def tar_file(self):
637
+ return f"{self.folder}.tar.gz"
638
+
639
+ def default_runtime_checkpoint_path(self):
640
+ return DEFAULT_RUNTIME_DOWNLOAD_PATH
641
+
642
+ def download_checkpoints(
643
+ self, stage: str = DEFAULT_DOWNLOAD_CHECKPOINT_WHEN, checkpoint_path_override: str = None
644
+ ):
645
+ """
646
+ Downloads the checkpoints specified in the config file.
647
+
648
+ :param stage: The stage of the build process. This is used to determine when to download the
649
+ checkpoints. The stage can be one of ['build', 'upload', 'runtime']. If you want to force
650
+ downloading now then set stage to match e when field of the checkpoints section of you config.yaml.
651
+ :param checkpoint_path_override: The path to download the checkpoints to (with 1/checkpoints added as suffix). If not provided, the
652
+ default path is used based on the folder ModelUploader was initialized with. The checkpoint_suffix will be appended to the path.
653
+ If stage is 'runtime' and checkpoint_path_override is None, the default runtime path will be used.
654
+
655
+ :return: The path to the downloaded checkpoints. Even if it doesn't download anything, it will return the default path.
656
+ """
657
+ path = self.checkpoint_path # default checkpoint path.
658
+ if not self.config.get("checkpoints"):
659
+ logger.info("No checkpoints specified in the config file")
660
+ return path
661
+ clarifai_model_type_id = self.config.get('model').get('model_type_id')
662
+
663
+ loader_type, repo_id, hf_token, when, allowed_file_patterns, ignore_file_patterns = (
664
+ self._validate_config_checkpoints()
686
665
  )
687
- # If we don't already have the concepts, download the config.json file from HuggingFace
666
+ if stage not in ["build", "upload", "runtime"]:
667
+ raise Exception(
668
+ "Invalid stage provided, must be one of ['build', 'upload', 'runtime']"
669
+ )
670
+ if when != stage:
671
+ logger.info(
672
+ f"Skipping downloading checkpoints for stage {stage} since config.yaml says to download them at stage {when}"
673
+ )
674
+ return path
675
+
676
+ success = False
688
677
  if loader_type == "huggingface":
689
- # If the config.yaml says we'll download in the future (build time or runtime) then we need to get this config now.
690
- if when != "upload" and not HuggingFaceLoader.validate_config(self.checkpoint_path):
691
- input(
692
- "Press Enter to download the HuggingFace model's config.json file to infer the concepts and continue..."
678
+ loader = HuggingFaceLoader(
679
+ repo_id=repo_id, token=hf_token, model_type_id=clarifai_model_type_id
680
+ )
681
+ # for runtime default to /tmp path
682
+ if stage == "runtime" and checkpoint_path_override is None:
683
+ checkpoint_path_override = self.default_runtime_checkpoint_path()
684
+ path = checkpoint_path_override if checkpoint_path_override else self.checkpoint_path
685
+ success = loader.download_checkpoints(
686
+ path,
687
+ allowed_file_patterns=allowed_file_patterns,
688
+ ignore_file_patterns=ignore_file_patterns,
693
689
  )
694
- loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
695
- loader.download_config(self.checkpoint_path)
696
690
 
697
- else:
698
- logger.error(
699
- "No checkpoints specified in the config.yaml file to infer the concepts. Please either specify the concepts directly in the config.yaml file or include a checkpoints section to download the HF model's config.json file to infer the concepts."
691
+ if loader_type:
692
+ if not success:
693
+ logger.error(f"Failed to download checkpoints for model {repo_id}")
694
+ sys.exit(1)
695
+ else:
696
+ logger.info(f"Downloaded checkpoints for model {repo_id}")
697
+ return path
698
+
699
+ def _concepts_protos_from_concepts(self, concepts):
700
+ concept_protos = []
701
+ for concept in concepts:
702
+ concept_protos.append(
703
+ resources_pb2.Concept(
704
+ id=str(concept[0]),
705
+ name=concept[1],
706
+ )
707
+ )
708
+ return concept_protos
709
+
710
+ def hf_labels_to_config(self, labels, config_file):
711
+ with open(config_file, 'r') as file:
712
+ config = yaml.safe_load(file)
713
+ model = config.get('model')
714
+ model_type_id = model.get('model_type_id')
715
+ assert model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE, (
716
+ f"Model type {model_type_id} not supported for concepts"
700
717
  )
701
- return
702
-
703
- model_version_proto = self.get_model_version_proto()
704
-
705
- def filter_func(tarinfo):
706
- name = tarinfo.name
707
- exclude = [self.tar_file, "*~", "*.pyc", "*.pyo", "__pycache__"]
708
- if when != "upload":
709
- exclude.append(self.checkpoint_suffix)
710
- return None if any(name.endswith(ex) for ex in exclude) else tarinfo
711
-
712
- with tarfile.open(self.tar_file, "w:gz") as tar:
713
- tar.add(self.folder, arcname=".", filter=filter_func)
714
- logger.debug("Tarring complete, about to start upload.")
715
-
716
- file_size = os.path.getsize(self.tar_file)
717
- logger.debug(f"Size of the tar is: {file_size} bytes")
718
-
719
- self.storage_request_size = self._get_tar_file_content_size(file_path)
720
- if when != "upload" and self.config.get("checkpoints"):
721
- # Get the checkpoint size to add to the storage request.
722
- # First check for the env variable, then try querying huggingface. If all else fails, use the default.
723
- checkpoint_size = os.environ.get('CHECKPOINT_SIZE_BYTES', 0)
724
- if not checkpoint_size:
725
- _, repo_id, _, _, _, _ = self._validate_config_checkpoints()
726
- checkpoint_size = HuggingFaceLoader.get_huggingface_checkpoint_total_size(repo_id)
727
- if not checkpoint_size:
728
- checkpoint_size = self.DEFAULT_CHECKPOINT_SIZE
729
- self.storage_request_size += checkpoint_size
730
-
731
- resp = self.maybe_create_model()
732
- if not self.check_model_exists():
733
- logger.error(f"Failed to create model: {self.model_proto.id}. Details: {resp}")
734
- sys.exit(1)
735
-
736
- for response in self.client.STUB.PostModelVersionsUpload(
737
- self.model_version_stream_upload_iterator(model_version_proto, file_path),):
738
- percent_completed = 0
739
- if response.status.code == status_code_pb2.UPLOAD_IN_PROGRESS:
740
- percent_completed = response.status.percent_completed
741
- details = response.status.details
742
-
743
- _clear_line()
744
- print(
745
- f"Status: {response.status.description}, "
746
- f"Progress: {percent_completed}% - {details} ",
747
- f"request_id: {response.status.req_id}",
748
- end='\r',
749
- flush=True)
750
- if response.status.code != status_code_pb2.MODEL_BUILDING:
751
- logger.error(f"Failed to upload model version: {response}")
752
- return
753
- self.model_version_id = response.model_version_id
754
- logger.info(f"Created Model Version ID: {self.model_version_id}")
755
- logger.info(f"Full url to that version is: {self.model_url}")
756
- try:
757
- self.monitor_model_build()
758
- finally:
759
- if os.path.exists(self.tar_file):
760
- logger.debug(f"Cleaning up upload file: {self.tar_file}")
761
- os.remove(self.tar_file)
762
-
763
- def model_version_stream_upload_iterator(self, model_version_proto, file_path):
764
- yield self.init_upload_model_version(model_version_proto, file_path)
765
- with open(file_path, "rb") as f:
766
- file_size = os.path.getsize(file_path)
767
- chunk_size = int(127 * 1024 * 1024) # 127MB chunk size
768
- num_chunks = (file_size // chunk_size) + 1
769
- logger.info("Uploading file...")
770
- logger.debug(f"File size: {file_size}")
771
- logger.debug(f"Chunk size: {chunk_size}")
772
- logger.debug(f"Number of chunks: {num_chunks}")
773
- read_so_far = 0
774
- for part_id in range(num_chunks):
718
+ concept_protos = self._concepts_protos_from_concepts(labels)
719
+
720
+ config['concepts'] = [
721
+ {'id': concept.id, 'name': concept.name} for concept in concept_protos
722
+ ]
723
+
724
+ with open(config_file, 'w') as file:
725
+ yaml.dump(config, file, sort_keys=False)
726
+ concepts = config.get('concepts')
727
+ logger.info(f"Updated config.yaml with {len(concepts)} concepts.")
728
+
729
+ def get_model_version_proto(self):
730
+ signatures = self.get_method_signatures()
731
+ model_version_proto = resources_pb2.ModelVersion(
732
+ pretrained_model_config=resources_pb2.PretrainedModelConfig(),
733
+ inference_compute_info=self.inference_compute_info,
734
+ method_signatures=signatures,
735
+ )
736
+
737
+ model_type_id = self.config.get('model').get('model_type_id')
738
+ if model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE:
739
+ if 'concepts' in self.config:
740
+ labels = self.config.get('concepts')
741
+ logger.info(f"Found {len(labels)} concepts in the config file.")
742
+ for concept in labels:
743
+ concept_proto = json_format.ParseDict(concept, resources_pb2.Concept())
744
+ model_version_proto.output_info.data.concepts.append(concept_proto)
745
+ elif self.config.get("checkpoints") and HuggingFaceLoader.validate_concept(
746
+ self.checkpoint_path
747
+ ):
748
+ labels = HuggingFaceLoader.fetch_labels(self.checkpoint_path)
749
+ logger.info(f"Found {len(labels)} concepts from the model checkpoints.")
750
+ # sort the concepts by id and then update the config file
751
+ labels = sorted(labels.items(), key=lambda x: int(x[0]))
752
+
753
+ config_file = os.path.join(self.folder, 'config.yaml')
754
+ try:
755
+ self.hf_labels_to_config(labels, config_file)
756
+ except Exception as e:
757
+ logger.error(f"Failed to update the config.yaml file with the concepts: {e}")
758
+
759
+ model_version_proto.output_info.data.concepts.extend(
760
+ self._concepts_protos_from_concepts(labels)
761
+ )
762
+ return model_version_proto
763
+
764
+ def upload_model_version(self):
765
+ file_path = f"{self.folder}.tar.gz"
766
+ logger.debug(f"Will tar it into file: {file_path}")
767
+
768
+ model_type_id = self.config.get('model').get('model_type_id')
769
+ loader_type, repo_id, hf_token, when, _, _ = self._validate_config_checkpoints()
770
+
771
+ if (model_type_id in CONCEPTS_REQUIRED_MODEL_TYPE) and 'concepts' not in self.config:
772
+ logger.info(
773
+ f"Model type {model_type_id} requires concepts to be specified in the config.yaml file.."
774
+ )
775
+ if self.config.get("checkpoints"):
776
+ logger.info(
777
+ "Checkpoints specified in the config.yaml file, will download the HF model's config.json file to infer the concepts."
778
+ )
779
+ # If we don't already have the concepts, download the config.json file from HuggingFace
780
+ if loader_type == "huggingface":
781
+ # If the config.yaml says we'll download in the future (build time or runtime) then we need to get this config now.
782
+ if when != "upload" and not HuggingFaceLoader.validate_config(
783
+ self.checkpoint_path
784
+ ):
785
+ input(
786
+ "Press Enter to download the HuggingFace model's config.json file to infer the concepts and continue..."
787
+ )
788
+ loader = HuggingFaceLoader(repo_id=repo_id, token=hf_token)
789
+ loader.download_config(self.checkpoint_path)
790
+
791
+ else:
792
+ logger.error(
793
+ "No checkpoints specified in the config.yaml file to infer the concepts. Please either specify the concepts directly in the config.yaml file or include a checkpoints section to download the HF model's config.json file to infer the concepts."
794
+ )
795
+ return
796
+
797
+ model_version_proto = self.get_model_version_proto()
798
+
799
+ def filter_func(tarinfo):
800
+ name = tarinfo.name
801
+ exclude = [self.tar_file, "*~", "*.pyc", "*.pyo", "__pycache__"]
802
+ if when != "upload":
803
+ exclude.append(self.checkpoint_suffix)
804
+ return None if any(name.endswith(ex) for ex in exclude) else tarinfo
805
+
806
+ with tarfile.open(self.tar_file, "w:gz") as tar:
807
+ tar.add(self.folder, arcname=".", filter=filter_func)
808
+ logger.debug("Tarring complete, about to start upload.")
809
+
810
+ file_size = os.path.getsize(self.tar_file)
811
+ logger.debug(f"Size of the tar is: {file_size} bytes")
812
+
813
+ self.storage_request_size = self._get_tar_file_content_size(file_path)
814
+ if when != "upload" and self.config.get("checkpoints"):
815
+ # Get the checkpoint size to add to the storage request.
816
+ # First check for the env variable, then try querying huggingface. If all else fails, use the default.
817
+ checkpoint_size = os.environ.get('CHECKPOINT_SIZE_BYTES', 0)
818
+ if not checkpoint_size:
819
+ _, repo_id, _, _, _, _ = self._validate_config_checkpoints()
820
+ checkpoint_size = HuggingFaceLoader.get_huggingface_checkpoint_total_size(repo_id)
821
+ if not checkpoint_size:
822
+ checkpoint_size = self.DEFAULT_CHECKPOINT_SIZE
823
+ self.storage_request_size += checkpoint_size
824
+
825
+ resp = self.maybe_create_model()
826
+ if not self.check_model_exists():
827
+ logger.error(f"Failed to create model: {self.model_proto.id}. Details: {resp}")
828
+ sys.exit(1)
829
+
830
+ for response in self.client.STUB.PostModelVersionsUpload(
831
+ self.model_version_stream_upload_iterator(model_version_proto, file_path),
832
+ ):
833
+ percent_completed = 0
834
+ if response.status.code == status_code_pb2.UPLOAD_IN_PROGRESS:
835
+ percent_completed = response.status.percent_completed
836
+ details = response.status.details
837
+
838
+ _clear_line()
839
+ print(
840
+ f"Status: {response.status.description}, Progress: {percent_completed}% - {details} ",
841
+ f"request_id: {response.status.req_id}",
842
+ end='\r',
843
+ flush=True,
844
+ )
845
+ if response.status.code != status_code_pb2.MODEL_BUILDING:
846
+ logger.error(f"Failed to upload model version: {response}")
847
+ return
848
+ self.model_version_id = response.model_version_id
849
+ logger.info(f"Created Model Version ID: {self.model_version_id}")
850
+ logger.info(f"Full url to that version is: {self.model_url}")
775
851
  try:
776
- chunk_size = min(chunk_size, file_size - read_so_far)
777
- chunk = f.read(chunk_size)
778
- if not chunk:
779
- break
780
- read_so_far += len(chunk)
781
- yield service_pb2.PostModelVersionsUploadRequest(
782
- content_part=resources_pb2.UploadContentPart(
783
- data=chunk,
784
- part_number=part_id + 1,
785
- range_start=read_so_far,
786
- ))
787
- except Exception as e:
788
- logger.exception(f"\nError uploading file: {e}")
789
- break
790
-
791
- if read_so_far == file_size:
792
- logger.info("Upload complete!")
793
-
794
- def init_upload_model_version(self, model_version_proto, file_path):
795
- file_size = os.path.getsize(file_path)
796
- logger.debug(f"Uploading model version of model {self.model_proto.id}")
797
- logger.debug(f"Using file '{os.path.basename(file_path)}' of size: {file_size} bytes")
798
- result = service_pb2.PostModelVersionsUploadRequest(
799
- upload_config=service_pb2.PostModelVersionsUploadConfig(
852
+ self.monitor_model_build()
853
+ finally:
854
+ if os.path.exists(self.tar_file):
855
+ logger.debug(f"Cleaning up upload file: {self.tar_file}")
856
+ os.remove(self.tar_file)
857
+
858
+ def model_version_stream_upload_iterator(self, model_version_proto, file_path):
859
+ yield self.init_upload_model_version(model_version_proto, file_path)
860
+ with open(file_path, "rb") as f:
861
+ file_size = os.path.getsize(file_path)
862
+ chunk_size = int(127 * 1024 * 1024) # 127MB chunk size
863
+ num_chunks = (file_size // chunk_size) + 1
864
+ logger.info("Uploading file...")
865
+ logger.debug(f"File size: {file_size}")
866
+ logger.debug(f"Chunk size: {chunk_size}")
867
+ logger.debug(f"Number of chunks: {num_chunks}")
868
+ read_so_far = 0
869
+ for part_id in range(num_chunks):
870
+ try:
871
+ chunk_size = min(chunk_size, file_size - read_so_far)
872
+ chunk = f.read(chunk_size)
873
+ if not chunk:
874
+ break
875
+ read_so_far += len(chunk)
876
+ yield service_pb2.PostModelVersionsUploadRequest(
877
+ content_part=resources_pb2.UploadContentPart(
878
+ data=chunk,
879
+ part_number=part_id + 1,
880
+ range_start=read_so_far,
881
+ )
882
+ )
883
+ except Exception as e:
884
+ logger.exception(f"\nError uploading file: {e}")
885
+ break
886
+
887
+ if read_so_far == file_size:
888
+ logger.info("Upload complete!")
889
+
890
+ def init_upload_model_version(self, model_version_proto, file_path):
891
+ file_size = os.path.getsize(file_path)
892
+ logger.debug(f"Uploading model version of model {self.model_proto.id}")
893
+ logger.debug(f"Using file '{os.path.basename(file_path)}' of size: {file_size} bytes")
894
+ result = service_pb2.PostModelVersionsUploadRequest(
895
+ upload_config=service_pb2.PostModelVersionsUploadConfig(
896
+ user_app_id=self.client.user_app_id,
897
+ model_id=self.model_proto.id,
898
+ model_version=model_version_proto,
899
+ total_size=file_size,
900
+ storage_request_size=self.storage_request_size,
901
+ is_v3=self.is_v3,
902
+ )
903
+ )
904
+ return result
905
+
906
+ def get_model_build_logs(self):
907
+ logs_request = service_pb2.ListLogEntriesRequest(
908
+ log_type="builder",
800
909
  user_app_id=self.client.user_app_id,
801
910
  model_id=self.model_proto.id,
802
- model_version=model_version_proto,
803
- total_size=file_size,
804
- storage_request_size=self.storage_request_size,
805
- is_v3=self.is_v3,
806
- ))
807
- return result
808
-
809
- def get_model_build_logs(self):
810
- logs_request = service_pb2.ListLogEntriesRequest(
811
- log_type="builder",
812
- user_app_id=self.client.user_app_id,
813
- model_id=self.model_proto.id,
814
- model_version_id=self.model_version_id,
815
- page=1,
816
- per_page=50)
817
- response = self.client.STUB.ListLogEntries(logs_request)
818
-
819
- return response
820
-
821
- def monitor_model_build(self):
822
- st = time.time()
823
- seen_logs = set() # To avoid duplicate log messages
824
- while True:
825
- resp = self.client.STUB.GetModelVersion(
826
- service_pb2.GetModelVersionRequest(
827
- user_app_id=self.client.user_app_id,
828
- model_id=self.model_proto.id,
829
- version_id=self.model_version_id,
830
- ))
831
-
832
- status_code = resp.model_version.status.code
833
- logs = self.get_model_build_logs()
834
- for log_entry in logs.log_entries:
835
- if log_entry.url not in seen_logs:
836
- seen_logs.add(log_entry.url)
837
- logger.info(f"{escape(log_entry.message.strip())}")
838
- if status_code == status_code_pb2.MODEL_BUILDING:
839
- print(f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True)
840
-
841
- # Fetch and display the logs
842
- time.sleep(1)
843
- elif status_code == status_code_pb2.MODEL_TRAINED:
844
- logger.info("Model build complete!")
845
- logger.info(f"Build time elapsed {time.time() - st:.1f}s)")
846
- logger.info(f"Check out the model at {self.model_url} version: {self.model_version_id}")
847
- return True
848
- else:
849
- logger.info(
850
- f"\nModel build failed with status: {resp.model_version.status} and response {resp}")
851
- return False
911
+ model_version_id=self.model_version_id,
912
+ page=1,
913
+ per_page=50,
914
+ )
915
+ response = self.client.STUB.ListLogEntries(logs_request)
916
+
917
+ return response
918
+
919
+ def monitor_model_build(self):
920
+ st = time.time()
921
+ seen_logs = set() # To avoid duplicate log messages
922
+ while True:
923
+ resp = self.client.STUB.GetModelVersion(
924
+ service_pb2.GetModelVersionRequest(
925
+ user_app_id=self.client.user_app_id,
926
+ model_id=self.model_proto.id,
927
+ version_id=self.model_version_id,
928
+ )
929
+ )
930
+
931
+ status_code = resp.model_version.status.code
932
+ logs = self.get_model_build_logs()
933
+ for log_entry in logs.log_entries:
934
+ if log_entry.url not in seen_logs:
935
+ seen_logs.add(log_entry.url)
936
+ logger.info(f"{escape(log_entry.message.strip())}")
937
+ if status_code == status_code_pb2.MODEL_BUILDING:
938
+ print(
939
+ f"Model is building... (elapsed {time.time() - st:.1f}s)", end='\r', flush=True
940
+ )
941
+
942
+ # Fetch and display the logs
943
+ time.sleep(1)
944
+ elif status_code == status_code_pb2.MODEL_TRAINED:
945
+ logger.info("Model build complete!")
946
+ logger.info(f"Build time elapsed {time.time() - st:.1f}s)")
947
+ logger.info(
948
+ f"Check out the model at {self.model_url} version: {self.model_version_id}"
949
+ )
950
+ return True
951
+ else:
952
+ logger.info(
953
+ f"\nModel build failed with status: {resp.model_version.status} and response {resp}"
954
+ )
955
+ return False
852
956
 
853
957
 
854
958
  def upload_model(folder, stage, skip_dockerfile):
855
- """
856
- Uploads a model to Clarifai.
857
-
858
- :param folder: The folder containing the model files.
859
- :param stage: The stage we are calling download checkpoints from. Typically this would "upload" and will download checkpoints if config.yaml checkpoints section has when set to "upload". Other options include "runtime" to be used in load_model or "upload" to be used during model upload. Set this stage to whatever you have in config.yaml to force downloading now.
860
- :param skip_dockerfile: If True, will not create a Dockerfile.
861
- """
862
- builder = ModelBuilder(folder)
863
- builder.download_checkpoints(stage=stage)
864
- if not skip_dockerfile:
865
- builder.create_dockerfile()
866
- exists = builder.check_model_exists()
867
- if exists:
868
- logger.info(
869
- f"Model already exists at {builder.model_url}, this upload will create a new version for it."
870
- )
871
- else:
872
- logger.info(f"New model will be created at {builder.model_url} with it's first version.")
873
-
874
- input("Press Enter to continue...")
875
- builder.upload_model_version()
959
+ """
960
+ Uploads a model to Clarifai.
961
+
962
+ :param folder: The folder containing the model files.
963
+ :param stage: The stage we are calling download checkpoints from. Typically this would "upload" and will download checkpoints if config.yaml checkpoints section has when set to "upload". Other options include "runtime" to be used in load_model or "upload" to be used during model upload. Set this stage to whatever you have in config.yaml to force downloading now.
964
+ :param skip_dockerfile: If True, will not create a Dockerfile.
965
+ """
966
+ builder = ModelBuilder(folder)
967
+ builder.download_checkpoints(stage=stage)
968
+ if not skip_dockerfile:
969
+ builder.create_dockerfile()
970
+ exists = builder.check_model_exists()
971
+ if exists:
972
+ logger.info(
973
+ f"Model already exists at {builder.model_url}, this upload will create a new version for it."
974
+ )
975
+ else:
976
+ logger.info(f"New model will be created at {builder.model_url} with it's first version.")
977
+
978
+ input("Press Enter to continue...")
979
+ builder.upload_model_version()