clarifai 11.3.0rc2__py3-none-any.whl → 11.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (300)
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/__main__.py +1 -1
  3. clarifai/cli/base.py +144 -136
  4. clarifai/cli/compute_cluster.py +45 -31
  5. clarifai/cli/deployment.py +93 -76
  6. clarifai/cli/model.py +578 -180
  7. clarifai/cli/nodepool.py +100 -82
  8. clarifai/client/__init__.py +12 -2
  9. clarifai/client/app.py +973 -911
  10. clarifai/client/auth/helper.py +345 -342
  11. clarifai/client/auth/register.py +7 -7
  12. clarifai/client/auth/stub.py +107 -106
  13. clarifai/client/base.py +185 -178
  14. clarifai/client/compute_cluster.py +214 -180
  15. clarifai/client/dataset.py +793 -698
  16. clarifai/client/deployment.py +55 -50
  17. clarifai/client/input.py +1223 -1088
  18. clarifai/client/lister.py +47 -45
  19. clarifai/client/model.py +1939 -1717
  20. clarifai/client/model_client.py +525 -502
  21. clarifai/client/module.py +82 -73
  22. clarifai/client/nodepool.py +358 -213
  23. clarifai/client/runner.py +58 -0
  24. clarifai/client/search.py +342 -309
  25. clarifai/client/user.py +419 -414
  26. clarifai/client/workflow.py +294 -274
  27. clarifai/constants/dataset.py +11 -17
  28. clarifai/constants/model.py +8 -2
  29. clarifai/datasets/export/inputs_annotations.py +233 -217
  30. clarifai/datasets/upload/base.py +63 -51
  31. clarifai/datasets/upload/features.py +43 -38
  32. clarifai/datasets/upload/image.py +237 -207
  33. clarifai/datasets/upload/loaders/coco_captions.py +34 -32
  34. clarifai/datasets/upload/loaders/coco_detection.py +72 -65
  35. clarifai/datasets/upload/loaders/imagenet_classification.py +57 -53
  36. clarifai/datasets/upload/loaders/xview_detection.py +274 -132
  37. clarifai/datasets/upload/multimodal.py +55 -46
  38. clarifai/datasets/upload/text.py +55 -47
  39. clarifai/datasets/upload/utils.py +250 -234
  40. clarifai/errors.py +51 -50
  41. clarifai/models/api.py +260 -238
  42. clarifai/modules/css.py +50 -50
  43. clarifai/modules/pages.py +33 -33
  44. clarifai/rag/rag.py +312 -288
  45. clarifai/rag/utils.py +91 -84
  46. clarifai/runners/models/model_builder.py +906 -802
  47. clarifai/runners/models/model_class.py +370 -331
  48. clarifai/runners/models/model_run_locally.py +459 -419
  49. clarifai/runners/models/model_runner.py +170 -162
  50. clarifai/runners/models/model_servicer.py +78 -70
  51. clarifai/runners/server.py +111 -101
  52. clarifai/runners/utils/code_script.py +225 -187
  53. clarifai/runners/utils/const.py +4 -1
  54. clarifai/runners/utils/data_types/__init__.py +12 -0
  55. clarifai/runners/utils/data_types/data_types.py +598 -0
  56. clarifai/runners/utils/data_utils.py +387 -440
  57. clarifai/runners/utils/loader.py +247 -227
  58. clarifai/runners/utils/method_signatures.py +411 -386
  59. clarifai/runners/utils/openai_convertor.py +108 -109
  60. clarifai/runners/utils/serializers.py +175 -179
  61. clarifai/runners/utils/url_fetcher.py +35 -35
  62. clarifai/schema/search.py +56 -63
  63. clarifai/urls/helper.py +125 -102
  64. clarifai/utils/cli.py +129 -123
  65. clarifai/utils/config.py +127 -87
  66. clarifai/utils/constants.py +49 -0
  67. clarifai/utils/evaluation/helpers.py +503 -466
  68. clarifai/utils/evaluation/main.py +431 -393
  69. clarifai/utils/evaluation/testset_annotation_parser.py +154 -144
  70. clarifai/utils/logging.py +324 -306
  71. clarifai/utils/misc.py +60 -56
  72. clarifai/utils/model_train.py +165 -146
  73. clarifai/utils/protobuf.py +126 -103
  74. clarifai/versions.py +3 -1
  75. clarifai/workflows/export.py +48 -50
  76. clarifai/workflows/utils.py +39 -36
  77. clarifai/workflows/validate.py +55 -43
  78. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/METADATA +16 -6
  79. clarifai-11.4.0.dist-info/RECORD +109 -0
  80. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/WHEEL +1 -1
  81. clarifai/__pycache__/__init__.cpython-310.pyc +0 -0
  82. clarifai/__pycache__/__init__.cpython-311.pyc +0 -0
  83. clarifai/__pycache__/__init__.cpython-39.pyc +0 -0
  84. clarifai/__pycache__/errors.cpython-310.pyc +0 -0
  85. clarifai/__pycache__/errors.cpython-311.pyc +0 -0
  86. clarifai/__pycache__/versions.cpython-310.pyc +0 -0
  87. clarifai/__pycache__/versions.cpython-311.pyc +0 -0
  88. clarifai/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  89. clarifai/cli/__pycache__/__init__.cpython-311.pyc +0 -0
  90. clarifai/cli/__pycache__/base.cpython-310.pyc +0 -0
  91. clarifai/cli/__pycache__/base.cpython-311.pyc +0 -0
  92. clarifai/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  93. clarifai/cli/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  94. clarifai/cli/__pycache__/compute_cluster.cpython-311.pyc +0 -0
  95. clarifai/cli/__pycache__/deployment.cpython-310.pyc +0 -0
  96. clarifai/cli/__pycache__/deployment.cpython-311.pyc +0 -0
  97. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  98. clarifai/cli/__pycache__/model.cpython-311.pyc +0 -0
  99. clarifai/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  100. clarifai/cli/__pycache__/nodepool.cpython-310.pyc +0 -0
  101. clarifai/cli/__pycache__/nodepool.cpython-311.pyc +0 -0
  102. clarifai/client/__pycache__/__init__.cpython-310.pyc +0 -0
  103. clarifai/client/__pycache__/__init__.cpython-311.pyc +0 -0
  104. clarifai/client/__pycache__/__init__.cpython-39.pyc +0 -0
  105. clarifai/client/__pycache__/app.cpython-310.pyc +0 -0
  106. clarifai/client/__pycache__/app.cpython-311.pyc +0 -0
  107. clarifai/client/__pycache__/app.cpython-39.pyc +0 -0
  108. clarifai/client/__pycache__/base.cpython-310.pyc +0 -0
  109. clarifai/client/__pycache__/base.cpython-311.pyc +0 -0
  110. clarifai/client/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  111. clarifai/client/__pycache__/compute_cluster.cpython-311.pyc +0 -0
  112. clarifai/client/__pycache__/dataset.cpython-310.pyc +0 -0
  113. clarifai/client/__pycache__/dataset.cpython-311.pyc +0 -0
  114. clarifai/client/__pycache__/deployment.cpython-310.pyc +0 -0
  115. clarifai/client/__pycache__/deployment.cpython-311.pyc +0 -0
  116. clarifai/client/__pycache__/input.cpython-310.pyc +0 -0
  117. clarifai/client/__pycache__/input.cpython-311.pyc +0 -0
  118. clarifai/client/__pycache__/lister.cpython-310.pyc +0 -0
  119. clarifai/client/__pycache__/lister.cpython-311.pyc +0 -0
  120. clarifai/client/__pycache__/model.cpython-310.pyc +0 -0
  121. clarifai/client/__pycache__/model.cpython-311.pyc +0 -0
  122. clarifai/client/__pycache__/module.cpython-310.pyc +0 -0
  123. clarifai/client/__pycache__/module.cpython-311.pyc +0 -0
  124. clarifai/client/__pycache__/nodepool.cpython-310.pyc +0 -0
  125. clarifai/client/__pycache__/nodepool.cpython-311.pyc +0 -0
  126. clarifai/client/__pycache__/search.cpython-310.pyc +0 -0
  127. clarifai/client/__pycache__/search.cpython-311.pyc +0 -0
  128. clarifai/client/__pycache__/user.cpython-310.pyc +0 -0
  129. clarifai/client/__pycache__/user.cpython-311.pyc +0 -0
  130. clarifai/client/__pycache__/workflow.cpython-310.pyc +0 -0
  131. clarifai/client/__pycache__/workflow.cpython-311.pyc +0 -0
  132. clarifai/client/auth/__pycache__/__init__.cpython-310.pyc +0 -0
  133. clarifai/client/auth/__pycache__/__init__.cpython-311.pyc +0 -0
  134. clarifai/client/auth/__pycache__/helper.cpython-310.pyc +0 -0
  135. clarifai/client/auth/__pycache__/helper.cpython-311.pyc +0 -0
  136. clarifai/client/auth/__pycache__/register.cpython-310.pyc +0 -0
  137. clarifai/client/auth/__pycache__/register.cpython-311.pyc +0 -0
  138. clarifai/client/auth/__pycache__/stub.cpython-310.pyc +0 -0
  139. clarifai/client/auth/__pycache__/stub.cpython-311.pyc +0 -0
  140. clarifai/client/cli/__init__.py +0 -0
  141. clarifai/client/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  142. clarifai/client/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  143. clarifai/client/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  144. clarifai/client/cli/base_cli.py +0 -88
  145. clarifai/client/cli/model_cli.py +0 -29
  146. clarifai/constants/__pycache__/base.cpython-310.pyc +0 -0
  147. clarifai/constants/__pycache__/base.cpython-311.pyc +0 -0
  148. clarifai/constants/__pycache__/dataset.cpython-310.pyc +0 -0
  149. clarifai/constants/__pycache__/dataset.cpython-311.pyc +0 -0
  150. clarifai/constants/__pycache__/input.cpython-310.pyc +0 -0
  151. clarifai/constants/__pycache__/input.cpython-311.pyc +0 -0
  152. clarifai/constants/__pycache__/model.cpython-310.pyc +0 -0
  153. clarifai/constants/__pycache__/model.cpython-311.pyc +0 -0
  154. clarifai/constants/__pycache__/rag.cpython-310.pyc +0 -0
  155. clarifai/constants/__pycache__/rag.cpython-311.pyc +0 -0
  156. clarifai/constants/__pycache__/search.cpython-310.pyc +0 -0
  157. clarifai/constants/__pycache__/search.cpython-311.pyc +0 -0
  158. clarifai/constants/__pycache__/workflow.cpython-310.pyc +0 -0
  159. clarifai/constants/__pycache__/workflow.cpython-311.pyc +0 -0
  160. clarifai/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  161. clarifai/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
  162. clarifai/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  163. clarifai/datasets/export/__pycache__/__init__.cpython-310.pyc +0 -0
  164. clarifai/datasets/export/__pycache__/__init__.cpython-311.pyc +0 -0
  165. clarifai/datasets/export/__pycache__/__init__.cpython-39.pyc +0 -0
  166. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-310.pyc +0 -0
  167. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-311.pyc +0 -0
  168. clarifai/datasets/upload/__pycache__/__init__.cpython-310.pyc +0 -0
  169. clarifai/datasets/upload/__pycache__/__init__.cpython-311.pyc +0 -0
  170. clarifai/datasets/upload/__pycache__/__init__.cpython-39.pyc +0 -0
  171. clarifai/datasets/upload/__pycache__/base.cpython-310.pyc +0 -0
  172. clarifai/datasets/upload/__pycache__/base.cpython-311.pyc +0 -0
  173. clarifai/datasets/upload/__pycache__/features.cpython-310.pyc +0 -0
  174. clarifai/datasets/upload/__pycache__/features.cpython-311.pyc +0 -0
  175. clarifai/datasets/upload/__pycache__/image.cpython-310.pyc +0 -0
  176. clarifai/datasets/upload/__pycache__/image.cpython-311.pyc +0 -0
  177. clarifai/datasets/upload/__pycache__/multimodal.cpython-310.pyc +0 -0
  178. clarifai/datasets/upload/__pycache__/multimodal.cpython-311.pyc +0 -0
  179. clarifai/datasets/upload/__pycache__/text.cpython-310.pyc +0 -0
  180. clarifai/datasets/upload/__pycache__/text.cpython-311.pyc +0 -0
  181. clarifai/datasets/upload/__pycache__/utils.cpython-310.pyc +0 -0
  182. clarifai/datasets/upload/__pycache__/utils.cpython-311.pyc +0 -0
  183. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-311.pyc +0 -0
  184. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-39.pyc +0 -0
  185. clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-311.pyc +0 -0
  186. clarifai/datasets/upload/loaders/__pycache__/imagenet_classification.cpython-311.pyc +0 -0
  187. clarifai/models/__pycache__/__init__.cpython-39.pyc +0 -0
  188. clarifai/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  189. clarifai/rag/__pycache__/__init__.cpython-310.pyc +0 -0
  190. clarifai/rag/__pycache__/__init__.cpython-311.pyc +0 -0
  191. clarifai/rag/__pycache__/__init__.cpython-39.pyc +0 -0
  192. clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
  193. clarifai/rag/__pycache__/rag.cpython-311.pyc +0 -0
  194. clarifai/rag/__pycache__/rag.cpython-39.pyc +0 -0
  195. clarifai/rag/__pycache__/utils.cpython-310.pyc +0 -0
  196. clarifai/rag/__pycache__/utils.cpython-311.pyc +0 -0
  197. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  198. clarifai/runners/__pycache__/__init__.cpython-311.pyc +0 -0
  199. clarifai/runners/__pycache__/__init__.cpython-39.pyc +0 -0
  200. clarifai/runners/dockerfile_template/Dockerfile.cpu.template +0 -31
  201. clarifai/runners/dockerfile_template/Dockerfile.cuda.template +0 -42
  202. clarifai/runners/dockerfile_template/Dockerfile.nim +0 -71
  203. clarifai/runners/models/__pycache__/__init__.cpython-310.pyc +0 -0
  204. clarifai/runners/models/__pycache__/__init__.cpython-311.pyc +0 -0
  205. clarifai/runners/models/__pycache__/__init__.cpython-39.pyc +0 -0
  206. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  207. clarifai/runners/models/__pycache__/base_typed_model.cpython-311.pyc +0 -0
  208. clarifai/runners/models/__pycache__/base_typed_model.cpython-39.pyc +0 -0
  209. clarifai/runners/models/__pycache__/model_builder.cpython-311.pyc +0 -0
  210. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  211. clarifai/runners/models/__pycache__/model_class.cpython-311.pyc +0 -0
  212. clarifai/runners/models/__pycache__/model_run_locally.cpython-310-pytest-7.1.2.pyc +0 -0
  213. clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
  214. clarifai/runners/models/__pycache__/model_run_locally.cpython-311.pyc +0 -0
  215. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  216. clarifai/runners/models/__pycache__/model_runner.cpython-311.pyc +0 -0
  217. clarifai/runners/models/__pycache__/model_upload.cpython-310.pyc +0 -0
  218. clarifai/runners/models/base_typed_model.py +0 -238
  219. clarifai/runners/models/model_class_refract.py +0 -80
  220. clarifai/runners/models/model_upload.py +0 -607
  221. clarifai/runners/models/temp.py +0 -25
  222. clarifai/runners/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  223. clarifai/runners/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  224. clarifai/runners/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  225. clarifai/runners/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  226. clarifai/runners/utils/__pycache__/buffered_stream.cpython-310.pyc +0 -0
  227. clarifai/runners/utils/__pycache__/buffered_stream.cpython-38.pyc +0 -0
  228. clarifai/runners/utils/__pycache__/buffered_stream.cpython-39.pyc +0 -0
  229. clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
  230. clarifai/runners/utils/__pycache__/const.cpython-311.pyc +0 -0
  231. clarifai/runners/utils/__pycache__/constants.cpython-310.pyc +0 -0
  232. clarifai/runners/utils/__pycache__/constants.cpython-38.pyc +0 -0
  233. clarifai/runners/utils/__pycache__/constants.cpython-39.pyc +0 -0
  234. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  235. clarifai/runners/utils/__pycache__/data_handler.cpython-311.pyc +0 -0
  236. clarifai/runners/utils/__pycache__/data_handler.cpython-38.pyc +0 -0
  237. clarifai/runners/utils/__pycache__/data_handler.cpython-39.pyc +0 -0
  238. clarifai/runners/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
  239. clarifai/runners/utils/__pycache__/data_utils.cpython-311.pyc +0 -0
  240. clarifai/runners/utils/__pycache__/data_utils.cpython-38.pyc +0 -0
  241. clarifai/runners/utils/__pycache__/data_utils.cpython-39.pyc +0 -0
  242. clarifai/runners/utils/__pycache__/grpc_server.cpython-310.pyc +0 -0
  243. clarifai/runners/utils/__pycache__/grpc_server.cpython-38.pyc +0 -0
  244. clarifai/runners/utils/__pycache__/grpc_server.cpython-39.pyc +0 -0
  245. clarifai/runners/utils/__pycache__/health.cpython-310.pyc +0 -0
  246. clarifai/runners/utils/__pycache__/health.cpython-38.pyc +0 -0
  247. clarifai/runners/utils/__pycache__/health.cpython-39.pyc +0 -0
  248. clarifai/runners/utils/__pycache__/loader.cpython-310.pyc +0 -0
  249. clarifai/runners/utils/__pycache__/loader.cpython-311.pyc +0 -0
  250. clarifai/runners/utils/__pycache__/logging.cpython-310.pyc +0 -0
  251. clarifai/runners/utils/__pycache__/logging.cpython-38.pyc +0 -0
  252. clarifai/runners/utils/__pycache__/logging.cpython-39.pyc +0 -0
  253. clarifai/runners/utils/__pycache__/stream_source.cpython-310.pyc +0 -0
  254. clarifai/runners/utils/__pycache__/stream_source.cpython-39.pyc +0 -0
  255. clarifai/runners/utils/__pycache__/url_fetcher.cpython-310.pyc +0 -0
  256. clarifai/runners/utils/__pycache__/url_fetcher.cpython-311.pyc +0 -0
  257. clarifai/runners/utils/__pycache__/url_fetcher.cpython-38.pyc +0 -0
  258. clarifai/runners/utils/__pycache__/url_fetcher.cpython-39.pyc +0 -0
  259. clarifai/runners/utils/data_handler.py +0 -231
  260. clarifai/runners/utils/data_handler_refract.py +0 -213
  261. clarifai/runners/utils/data_types.py +0 -469
  262. clarifai/runners/utils/logger.py +0 -0
  263. clarifai/runners/utils/openai_format.py +0 -87
  264. clarifai/schema/__pycache__/search.cpython-310.pyc +0 -0
  265. clarifai/schema/__pycache__/search.cpython-311.pyc +0 -0
  266. clarifai/urls/__pycache__/helper.cpython-310.pyc +0 -0
  267. clarifai/urls/__pycache__/helper.cpython-311.pyc +0 -0
  268. clarifai/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  269. clarifai/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  270. clarifai/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  271. clarifai/utils/__pycache__/cli.cpython-310.pyc +0 -0
  272. clarifai/utils/__pycache__/cli.cpython-311.pyc +0 -0
  273. clarifai/utils/__pycache__/config.cpython-311.pyc +0 -0
  274. clarifai/utils/__pycache__/constants.cpython-310.pyc +0 -0
  275. clarifai/utils/__pycache__/constants.cpython-311.pyc +0 -0
  276. clarifai/utils/__pycache__/logging.cpython-310.pyc +0 -0
  277. clarifai/utils/__pycache__/logging.cpython-311.pyc +0 -0
  278. clarifai/utils/__pycache__/misc.cpython-310.pyc +0 -0
  279. clarifai/utils/__pycache__/misc.cpython-311.pyc +0 -0
  280. clarifai/utils/__pycache__/model_train.cpython-310.pyc +0 -0
  281. clarifai/utils/__pycache__/model_train.cpython-311.pyc +0 -0
  282. clarifai/utils/__pycache__/protobuf.cpython-311.pyc +0 -0
  283. clarifai/utils/evaluation/__pycache__/__init__.cpython-311.pyc +0 -0
  284. clarifai/utils/evaluation/__pycache__/__init__.cpython-39.pyc +0 -0
  285. clarifai/utils/evaluation/__pycache__/helpers.cpython-311.pyc +0 -0
  286. clarifai/utils/evaluation/__pycache__/main.cpython-311.pyc +0 -0
  287. clarifai/utils/evaluation/__pycache__/main.cpython-39.pyc +0 -0
  288. clarifai/workflows/__pycache__/__init__.cpython-310.pyc +0 -0
  289. clarifai/workflows/__pycache__/__init__.cpython-311.pyc +0 -0
  290. clarifai/workflows/__pycache__/__init__.cpython-39.pyc +0 -0
  291. clarifai/workflows/__pycache__/export.cpython-310.pyc +0 -0
  292. clarifai/workflows/__pycache__/export.cpython-311.pyc +0 -0
  293. clarifai/workflows/__pycache__/utils.cpython-310.pyc +0 -0
  294. clarifai/workflows/__pycache__/utils.cpython-311.pyc +0 -0
  295. clarifai/workflows/__pycache__/validate.cpython-310.pyc +0 -0
  296. clarifai/workflows/__pycache__/validate.cpython-311.pyc +0 -0
  297. clarifai-11.3.0rc2.dist-info/RECORD +0 -322
  298. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/entry_points.txt +0 -0
  299. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info/licenses}/LICENSE +0 -0
  300. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/top_level.txt +0 -0
@@ -9,20 +9,24 @@ from clarifai.constants.model import MAX_MODEL_PREDICT_INPUTS
9
9
  from clarifai.errors import UserError
10
10
  from clarifai.runners.utils import code_script, method_signatures
11
11
  from clarifai.runners.utils.data_utils import is_openai_chat_format
12
- from clarifai.runners.utils.method_signatures import (CompatibilitySerializer, deserialize,
13
- get_stream_from_signature, serialize,
14
- signatures_from_json)
12
+ from clarifai.runners.utils.method_signatures import (
13
+ CompatibilitySerializer,
14
+ deserialize,
15
+ get_stream_from_signature,
16
+ serialize,
17
+ signatures_from_json,
18
+ )
15
19
  from clarifai.utils.logging import logger
16
20
  from clarifai.utils.misc import BackoffIterator, status_is_retryable
17
21
 
18
22
 
19
23
  class ModelClient:
20
- '''
21
- Client for calling model predict, generate, and stream methods.
22
- '''
23
-
24
- def __init__(self, stub, request_template: service_pb2.PostModelOutputsRequest = None):
25
24
  '''
25
+ Client for calling model predict, generate, and stream methods.
26
+ '''
27
+
28
+ def __init__(self, stub, request_template: service_pb2.PostModelOutputsRequest = None):
29
+ '''
26
30
  Initialize the model client.
27
31
 
28
32
  Args:
@@ -30,501 +34,520 @@ class ModelClient:
30
34
  request_template: The template for the request to send to the model, including
31
35
  common fields like model_id, model_version, cluster, etc.
32
36
  '''
33
- self.STUB = stub
34
- self.request_template = request_template or service_pb2.PostModelOutputsRequest()
35
- self._method_signatures = None
36
- self._defined = False
37
+ self.STUB = stub
38
+ self.request_template = request_template or service_pb2.PostModelOutputsRequest()
39
+ self._method_signatures = None
40
+ self._defined = False
37
41
 
38
- def fetch(self):
39
- '''
40
- Fetch function signature definitions from the model and define the functions in the client
41
- '''
42
- if self._defined:
43
- return
44
- try:
45
- self._fetch_signatures()
46
- self._define_functions()
47
- finally:
48
- self._defined = True
49
-
50
- def __getattr__(self, name):
51
- if not self._defined:
52
- self.fetch()
53
- return self.__getattribute__(name)
54
-
55
- def _fetch_signatures(self):
56
- '''
57
- Fetch the method signatures from the model.
58
-
59
- Returns:
60
- Dict: The method signatures.
61
- '''
62
- try:
63
- response = self.STUB.GetModelVersion(
64
- service_pb2.GetModelVersionRequest(
65
- user_app_id=self.request_template.user_app_id,
66
- model_id=self.request_template.model_id,
67
- version_id=self.request_template.version_id,
68
- ))
69
-
70
- method_signatures = None
71
- if response.status.code == status_code_pb2.SUCCESS:
72
- method_signatures = response.model_version.method_signatures
73
- if response.status.code != status_code_pb2.SUCCESS:
74
- raise Exception(f"Model failed with response {response!r}")
75
- self._method_signatures = {}
76
- for method_signature in method_signatures:
77
- method_name = method_signature.name
78
- # check for duplicate method names
79
- if method_name in self._method_signatures:
80
- raise ValueError(f"Duplicate method name {method_name}")
81
- self._method_signatures[method_name] = method_signature
82
- if not self._method_signatures: # if no method signatures, try to fetch from the model
83
- self._fetch_signatures_backup()
84
- except Exception:
85
- # try to fetch from the model
86
- self._fetch_signatures_backup()
87
- if not self._method_signatures:
88
- raise ValueError("Failed to fetch method signatures from model and backup method")
89
-
90
- def _fetch_signatures_backup(self):
91
- '''
92
- This is a temporary method of fetching the method signatures from the model.
93
-
94
- Returns:
95
- Dict: The method signatures.
96
- '''
97
-
98
- request = service_pb2.PostModelOutputsRequest()
99
- request.CopyFrom(self.request_template)
100
- # request.model.model_version.output_info.params['_method_name'] = '_GET_SIGNATURES'
101
- inp = request.inputs.add() # empty input for this method
102
- inp.data.parts.add() # empty part for this input
103
- inp.data.metadata['_method_name'] = '_GET_SIGNATURES'
104
- start_time = time.time()
105
- backoff_iterator = BackoffIterator(10)
106
- while True:
107
- response = self.STUB.PostModelOutputs(request)
108
- if status_is_retryable(
109
- response.status.code) and time.time() - start_time < 60 * 10: # 10 minutes
110
- logger.info(f"Retrying model info fetch with response {response.status!r}")
111
- time.sleep(next(backoff_iterator))
112
- continue
113
- break
114
- if (response.status.code == status_code_pb2.INPUT_UNSUPPORTED_FORMAT or
115
- (response.status.code == status_code_pb2.SUCCESS and
116
- response.outputs[0].data.text.raw == '')):
117
- # return codes/values from older models that don't support _GET_SIGNATURES
118
- self._method_signatures = {}
119
- self._define_compatability_functions()
120
- return
121
- if response.status.code != status_code_pb2.SUCCESS:
122
- raise Exception(f"Model failed with response {response!r}")
123
- self._method_signatures = signatures_from_json(response.outputs[0].data.text.raw)
124
-
125
- def _define_functions(self):
126
- '''
127
- Define the functions based on the method signatures.
128
- '''
129
- for method_name, method_signature in self._method_signatures.items():
130
- # define the function in this client instance
131
- if resources_pb2.RunnerMethodType.Name(method_signature.method_type) == 'UNARY_UNARY':
132
- call_func = self._predict
133
- elif resources_pb2.RunnerMethodType.Name(method_signature.method_type) == 'UNARY_STREAMING':
134
- call_func = self._generate
135
- elif resources_pb2.RunnerMethodType.Name(
136
- method_signature.method_type) == 'STREAMING_STREAMING':
137
- call_func = self._stream
138
- else:
139
- raise ValueError(f"Unknown method type {method_signature.method_type}")
140
-
141
- # method argnames, in order, collapsing nested keys to corresponding user function args
142
- method_argnames = []
143
- for var in method_signature.input_fields:
144
- outer = var.name.split('.', 1)[0]
145
- if outer in method_argnames:
146
- continue
147
- method_argnames.append(outer)
148
-
149
- def bind_f(method_name, method_argnames, call_func):
150
-
151
- def f(*args, **kwargs):
152
- if len(args) > len(method_argnames):
153
- raise TypeError(
154
- f"{method_name}() takes {len(method_argnames)} positional arguments but {len(args)} were given"
42
+ def fetch(self):
43
+ '''
44
+ Fetch function signature definitions from the model and define the functions in the client
45
+ '''
46
+ if self._defined:
47
+ return
48
+ try:
49
+ self._fetch_signatures()
50
+ self._define_functions()
51
+ finally:
52
+ self._defined = True
53
+
54
+ def __getattr__(self, name):
55
+ if not self._defined:
56
+ self.fetch()
57
+ return self.__getattribute__(name)
58
+
59
+ def _fetch_signatures(self):
60
+ '''
61
+ Fetch the method signatures from the model.
62
+
63
+ Returns:
64
+ Dict: The method signatures.
65
+ '''
66
+ try:
67
+ response = self.STUB.GetModelVersion(
68
+ service_pb2.GetModelVersionRequest(
69
+ user_app_id=self.request_template.user_app_id,
70
+ model_id=self.request_template.model_id,
71
+ version_id=self.request_template.version_id,
72
+ )
155
73
  )
156
74
 
157
- if len(args) + len(kwargs) > len(method_argnames):
158
- raise TypeError(
159
- f"{method_name}() got an unexpected keyword argument {next(iter(kwargs))}")
160
- if len(args) == 1 and (not kwargs) and isinstance(args[0], list):
161
- batch_inputs = args[0]
162
- # Validate each input is a dictionary
163
- is_batch_input_valid = all(isinstance(input, dict) for input in batch_inputs)
164
- if is_batch_input_valid and (not is_openai_chat_format(batch_inputs)):
165
- # If the batch input is valid, call the function with the batch inputs and the method name
166
- return call_func(batch_inputs, method_name)
167
-
168
- for name, arg in zip(method_argnames, args): # handle positional with zip shortest
169
- if name in kwargs:
170
- raise TypeError(f"Multiple values for argument {name}")
171
- kwargs[name] = arg
172
- return call_func(kwargs, method_name)
173
-
174
- return f
175
-
176
- # need to bind method_name to the value, not the mutating loop variable
177
- f = bind_f(method_name, method_argnames, call_func)
178
-
179
- # set names, annotations and docstrings
180
- f.__name__ = method_name
181
- f.__qualname__ = f'{self.__class__.__name__}.{method_name}'
182
- input_annotations = code_script._get_annotations_source(method_signature)
183
- return_annotation = input_annotations.pop('return', (None, None))[0]
184
- sig = inspect.signature(f).replace(
185
- parameters=[
186
- inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=v[0])
187
- for k, v in input_annotations.items()
188
- ],
189
- return_annotation=return_annotation,
190
- )
191
- f.__signature__ = sig
192
- f.__doc__ = method_signature.description
193
- setattr(self, method_name, f)
194
-
195
- def available_methods(self) -> List[str]:
196
- """Get the available methods for this model.
197
-
198
- Returns:
199
- List[str]: The available methods.
200
- """
201
- if not self._defined:
202
- self.fetch()
203
- return self._method_signatures.keys()
204
-
205
- def method_signature(self, method_name: str) -> str:
206
- """Get the method signature for a method.
207
-
208
- Args:
209
- method_name (str): The name of the method.
210
-
211
- Returns:
212
- str: The method signature.
213
- """
214
- if not self._defined:
215
- self.fetch()
216
- return method_signatures.get_method_signature(self._method_signatures[method_name])
217
-
218
- def generate_client_script(self) -> str:
219
- """Generate a client script for this model.
220
-
221
- Returns:
222
- str: The client script.
223
- """
224
- if not self._defined:
225
- self.fetch()
226
- method_signatures = []
227
- for _, method_signature in self._method_signatures.items():
228
- method_signatures.append(method_signature)
229
- return code_script.generate_client_script(
230
- method_signatures,
231
- user_id=self.request_template.user_app_id.user_id,
232
- app_id=self.request_template.user_app_id.app_id,
233
- model_id=self.request_template.model_id)
234
-
235
- def _define_compatability_functions(self):
236
-
237
- serializer = CompatibilitySerializer()
238
-
239
- def predict(input: Any) -> Any:
240
- proto = resources_pb2.Input()
241
- serializer.serialize(proto.data, input)
242
- # always use text.raw for compat
243
- if proto.data.string_value:
244
- proto.data.text.raw = proto.data.string_value
245
- proto.data.string_value = ''
246
- response = self._predict_by_proto([proto])
247
- if response.status.code != status_code_pb2.SUCCESS:
248
- raise Exception(f"Model predict failed with response {response!r}")
249
- response_data = response.outputs[0].data
250
- if response_data.text.raw:
251
- response_data.string_value = response_data.text.raw
252
- response_data.text.raw = ''
253
- return serializer.deserialize(response_data)
254
-
255
- self.predict = predict
256
-
257
- def _predict(
258
- self,
259
- inputs, # TODO set up functions according to fetched signatures?
260
- method_name: str = 'predict',
261
- ) -> Any:
262
- input_signature = self._method_signatures[method_name].input_fields
263
- output_signature = self._method_signatures[method_name].output_fields
264
-
265
- batch_input = True
266
- if isinstance(inputs, dict):
267
- inputs = [inputs]
268
- batch_input = False
269
-
270
- proto_inputs = []
271
- for input in inputs:
272
- proto = resources_pb2.Input()
273
-
274
- serialize(input, input_signature, proto.data)
275
- proto_inputs.append(proto)
276
-
277
- response = self._predict_by_proto(proto_inputs, method_name)
278
-
279
- outputs = []
280
- for output in response.outputs:
281
- outputs.append(deserialize(output.data, output_signature, is_output=True))
282
- if batch_input:
283
- return outputs
284
- return outputs[0]
285
-
286
def _predict_by_proto(
    self,
    inputs: List[resources_pb2.Input],
    method_name: str = None,
    inference_params: Dict = None,
    output_config: Dict = None,
) -> service_pb2.MultiOutputResponse:
    """Send already-serialized inputs through ``PostModelOutputs``.

    Args:
        inputs (List[resources_pb2.Input]): The inputs to predict.
        method_name (str): The remote method name to call.
        inference_params (Dict): Inference parameters to override.
        output_config (Dict): Output configuration to override.

    Returns:
        service_pb2.MultiOutputResponse: The prediction response(s).

    Raises:
        UserError: If ``inputs`` is not a list or exceeds MAX_MODEL_PREDICT_INPUTS.
        Exception: If the final response status is not SUCCESS.
    """
    if not isinstance(inputs, list):
        raise UserError('Invalid inputs, inputs must be a list of Input objects.')
    if len(inputs) > MAX_MODEL_PREDICT_INPUTS:
        raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}.")

    request = service_pb2.PostModelOutputsRequest()
    request.CopyFrom(self.request_template)
    request.inputs.extend(inputs)

    if method_name:
        # TODO put in new proto field?
        for request_input in request.inputs:
            request_input.data.metadata['_method_name'] = method_name
    if inference_params:
        request.model.model_version.output_info.params.update(inference_params)
    if output_config:
        override = resources_pb2.OutputConfig(**output_config)
        request.model.model_version.output_info.output_config.MergeFrom(override)

    start_time = time.time()
    backoff_iterator = BackoffIterator(10)
    response = self.STUB.PostModelOutputs(request)
    # keep re-sending while the status is retryable, for at most 10 minutes
    while status_is_retryable(response.status.code) and time.time() - start_time < 60 * 10:
        logger.info("Model is still deploying, please wait...")
        time.sleep(next(backoff_iterator))
        response = self.STUB.PostModelOutputs(request)

    if response.status.code != status_code_pb2.SUCCESS:
        raise Exception(f"Model predict failed with response {response!r}")
    return response
338
-
339
def _generate(
    self,
    inputs,  # TODO set up functions according to fetched signatures?
    method_name: str = 'generate',
) -> Any:
    """Run a unary-in / streaming-out method and yield deserialized outputs.

    Args:
        inputs: A single dict of argument values, or a list of such dicts.
        method_name (str): Key into ``self._method_signatures`` selecting the method.

    Yields:
        Any: Per streamed response, one output for a dict input or a list of
        outputs for a list input.
    """
    signature = self._method_signatures[method_name]
    input_signature = signature.input_fields
    output_signature = signature.output_fields

    # a bare dict is a single call; remember that so each result can be unwrapped
    single_call = isinstance(inputs, dict)
    batch = [inputs] if single_call else inputs

    proto_inputs = []
    for item in batch:
        proto = resources_pb2.Input()
        serialize(item, input_signature, proto.data)
        proto_inputs.append(proto)

    for response in self._generate_by_proto(proto_inputs, method_name):
        outputs = [
            deserialize(output.data, output_signature, is_output=True)
            for output in response.outputs
        ]
        if single_call:
            yield outputs[0]
        else:
            yield outputs
368
-
369
def _generate_by_proto(
    self,
    inputs: List[resources_pb2.Input],
    method_name: str = None,
    inference_params: Dict = None,
    output_config: Dict = None,
):
    """Generate the stream output on model based on the given inputs.

    Args:
        inputs (list[Input]): The inputs to generate; at most MAX_MODEL_PREDICT_INPUTS.
        method_name (str): The remote method name to call.
        inference_params (dict): The inference params to override.
        output_config (dict): The output config to override.

    Yields:
        The responses streamed back by ``GenerateModelOutputs``, in order.

    Raises:
        UserError: If ``inputs`` is not a list or is too long.
        Exception: If the stream is empty or any response status is not SUCCESS.
    """
    if not isinstance(inputs, list):
        raise UserError('Invalid inputs, inputs must be a list of Input objects.')
    if len(inputs) > MAX_MODEL_PREDICT_INPUTS:
        raise UserError(
            f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}."
        )  # TODO Use Chunker for inputs len > 128

    request = service_pb2.PostModelOutputsRequest()
    request.CopyFrom(self.request_template)

    request.inputs.extend(inputs)

    if method_name:
        # TODO put in new proto field?
        for inp in request.inputs:
            inp.data.metadata['_method_name'] = method_name
    if inference_params:
        request.model.model_version.output_info.params.update(inference_params)
    if output_config:
        # protobuf messages have no MergeFromDict(); build the message from the
        # dict and merge it, exactly as _predict_by_proto does
        request.model.model_version.output_info.output_config.MergeFrom(
            resources_pb2.OutputConfig(**output_config)
        )

    start_time = time.time()
    backoff_iterator = BackoffIterator(10)
    started = False
    while not started:
        stream_response = self.STUB.GenerateModelOutputs(request)
        try:
            response = next(stream_response)  # get the first response
        except StopIteration:
            raise Exception("Model Generate failed with no response")
        # only the first response of each attempt is checked for a retryable status
        if status_is_retryable(response.status.code) and time.time() - start_time < 60 * 10:
            logger.info("Model is still deploying, please wait...")
            time.sleep(next(backoff_iterator))
            continue
        if response.status.code != status_code_pb2.SUCCESS:
            raise Exception(f"Model Generate failed with response {response.status!r}")
        started = True

    yield response  # yield the first response

    for response in stream_response:
        if response.status.code != status_code_pb2.SUCCESS:
            raise Exception(f"Model Generate failed with response {response.status!r}")
        yield response
428
-
429
def _stream(
    self,
    inputs,
    method_name: str = 'stream',
) -> Any:
    """Run a streaming-in / streaming-out method and yield deserialized outputs.

    Args:
        inputs: A dict of argument values, or a one-element list holding such a
            dict. The value under the signature's stream argument is consumed
            with ``next()`` and ``for``.
        method_name (str): Key into ``self._method_signatures`` selecting the method.

    Yields:
        Any: The deserialized single output of each streamed response.

    Raises:
        ValueError: If the input signature contains no stream field.
    """
    input_signature = self._method_signatures[method_name].input_fields
    output_signature = self._method_signatures[method_name].output_fields

    if isinstance(inputs, list):
        assert len(inputs) == 1, 'streaming methods do not support batched calls'
        inputs = inputs[0]
    assert isinstance(inputs, dict)
    # shallow copy: the pop and re-assignment below must not alter the caller's dict
    kwargs = dict(inputs)

    # find the streaming vars in the input signature, and the streaming input python param
    stream_sig = get_stream_from_signature(input_signature)
    if stream_sig is None:
        raise ValueError("Streaming method must have a Stream input")
    stream_argname = stream_sig.name

    # get the streaming input generator from the user-provided function arg values
    user_inputs_generator = kwargs.pop(stream_argname)

    def _input_proto_stream():
        # first item contains all the inputs and the first stream item
        proto = resources_pb2.Input()
        try:
            item = next(user_inputs_generator)
        except StopIteration:
            return  # no items to stream
        kwargs[stream_argname] = item
        serialize(kwargs, input_signature, proto.data)

        yield proto

        # subsequent items are just the stream items
        for item in user_inputs_generator:
            proto = resources_pb2.Input()
            serialize({stream_argname: item}, [stream_sig], proto.data)
            yield proto

    response_stream = self._stream_by_proto(_input_proto_stream(), method_name)

    for response in response_stream:
        assert len(response.outputs) == 1, 'streaming methods must have exactly one output'
        yield deserialize(response.outputs[0].data, output_signature, is_output=True)
475
-
476
def _req_iterator(
    self,
    input_iterator: Iterator[List[resources_pb2.Input]],
    method_name: str = None,
    inference_params: Dict = None,
    output_config: Dict = None,
):
    """Wrap each item of ``input_iterator`` in its own ``PostModelOutputsRequest``.

    Args:
        input_iterator: Yields either a list of inputs or a single input per request.
        method_name (str): If set, written to every input's ``_method_name`` metadata.
        inference_params (Dict): Inference parameters to override.
        output_config (Dict): Output configuration to override.

    Yields:
        service_pb2.PostModelOutputsRequest: One request per item of ``input_iterator``.
    """
    # base request shared by every yielded request
    request = service_pb2.PostModelOutputsRequest()
    request.CopyFrom(self.request_template)
    if inference_params:
        request.model.model_version.output_info.params.update(inference_params)
    if output_config:
        # protobuf messages have no MergeFromDict(); build the message from the
        # dict and merge it, exactly as _predict_by_proto does
        request.model.model_version.output_info.output_config.MergeFrom(
            resources_pb2.OutputConfig(**output_config)
        )
    for inputs in input_iterator:
        req = service_pb2.PostModelOutputsRequest()
        req.CopyFrom(request)
        if isinstance(inputs, list):
            req.inputs.extend(inputs)
        else:
            req.inputs.append(inputs)
        # TODO: put into new proto field?
        if method_name:
            for inp in req.inputs:
                inp.data.metadata['_method_name'] = method_name
        yield req
499
-
500
- def _stream_by_proto(self,
501
- inputs: Iterator[List[resources_pb2.Input]],
502
- method_name: str = None,
503
- inference_params: Dict = {},
504
- output_config: Dict = {}):
505
- """Generate the stream output on model based on the given stream of inputs.
506
- """
507
- # if not isinstance(inputs, Iterator[List[Input]]):
508
- # raise UserError('Invalid inputs, inputs must be a iterator of list of Input objects.')
509
-
510
- request = self._req_iterator(inputs, method_name, inference_params, output_config)
511
-
512
- start_time = time.time()
513
- backoff_iterator = BackoffIterator(10)
514
- generation_started = False
515
- while True:
516
- if generation_started:
517
- break
518
- stream_response = self.STUB.StreamModelOutputs(request)
519
- for response in stream_response:
520
- if status_is_retryable(response.status.code) and \
521
- time.time() - start_time < 60 * 10:
522
- logger.info("Model is still deploying, please wait...")
523
- time.sleep(next(backoff_iterator))
524
- break
75
+ method_signatures = None
76
+ if response.status.code == status_code_pb2.SUCCESS:
77
+ method_signatures = response.model_version.method_signatures
78
+ if response.status.code != status_code_pb2.SUCCESS:
79
+ raise Exception(f"Model failed with response {response!r}")
80
+ self._method_signatures = {}
81
+ for method_signature in method_signatures:
82
+ method_name = method_signature.name
83
+ # check for duplicate method names
84
+ if method_name in self._method_signatures:
85
+ raise ValueError(f"Duplicate method name {method_name}")
86
+ self._method_signatures[method_name] = method_signature
87
+ if not self._method_signatures: # if no method signatures, try to fetch from the model
88
+ self._fetch_signatures_backup()
89
+ except Exception:
90
+ # try to fetch from the model
91
+ self._fetch_signatures_backup()
92
+ if not self._method_signatures:
93
+ raise ValueError("Failed to fetch method signatures from model and backup method")
94
+
95
+ def _fetch_signatures_backup(self):
96
+ '''
97
+ This is a temporary method of fetching the method signatures from the model.
98
+
+ Sends a PostModelOutputs request whose single empty input carries the
+ metadata ``_method_name = '_GET_SIGNATURES'``, retrying retryable statuses
+ for up to 10 minutes.
+
99
+ Returns:
100
+ None: the result is stored on ``self._method_signatures``. For models that
+ answer INPUT_UNSUPPORTED_FORMAT, or SUCCESS with empty text, the signatures
+ are set to ``{}`` and the compatibility functions are defined instead.
+
+ Raises:
+ Exception: If the final response status is not SUCCESS.
101
+ '''
102
+
103
+ request = service_pb2.PostModelOutputsRequest()
104
+ request.CopyFrom(self.request_template)
105
+ # request.model.model_version.output_info.params['_method_name'] = '_GET_SIGNATURES'
106
+ inp = request.inputs.add() # empty input for this method
107
+ inp.data.parts.add() # empty part for this input
108
+ inp.data.metadata['_method_name'] = '_GET_SIGNATURES'
109
+ start_time = time.time()
110
+ backoff_iterator = BackoffIterator(10)
111
+ while True:
112
+ response = self.STUB.PostModelOutputs(request)
113
+ if (
114
+ status_is_retryable(response.status.code) and time.time() - start_time < 60 * 10
115
+ ): # 10 minutes
116
+ logger.info(f"Retrying model info fetch with response {response.status!r}")
117
+ time.sleep(next(backoff_iterator))
118
+ continue
119
+ break
120
+ if response.status.code == status_code_pb2.INPUT_UNSUPPORTED_FORMAT or (
121
+ response.status.code == status_code_pb2.SUCCESS
122
+ and response.outputs[0].data.text.raw == ''
123
+ ):
124
+ # return codes/values from older models that don't support _GET_SIGNATURES
125
+ self._method_signatures = {}
126
+ self._define_compatability_functions()
127
+ return
525
128
  if response.status.code != status_code_pb2.SUCCESS:
526
- raise Exception(f"Model Predict failed with response {response.status!r}")
527
- else:
528
- if not generation_started:
529
- generation_started = True
530
- yield response
129
+ raise Exception(f"Model failed with response {response!r}")
130
+ # the signatures arrive as JSON in the first output's text.raw
+ self._method_signatures = signatures_from_json(response.outputs[0].data.text.raw)
131
+
132
def _define_functions(self):
    """Attach one callable per fetched method signature to this client instance.

    Each callable accepts the method's arguments positionally or by keyword
    (or a single list of dicts as a batch), forwards them to the caller that
    matches the method type, and carries the method's name, signature and
    description for introspection.

    Raises:
        ValueError: If a signature has a method type other than UNARY_UNARY,
            UNARY_STREAMING or STREAMING_STREAMING.
    """
    # method-type name -> attribute holding the matching caller; resolved with
    # getattr inside the loop so the lookup happens per method, as before
    caller_attr_by_type = {
        'UNARY_UNARY': '_predict',
        'UNARY_STREAMING': '_generate',
        'STREAMING_STREAMING': '_stream',
    }

    def bind_f(method_name, method_argnames, call_func):
        def f(*args, **kwargs):
            expected = len(method_argnames)
            if len(args) > expected:
                raise TypeError(
                    f"{method_name}() takes {expected} positional arguments but {len(args)} were given"
                )

            if len(args) + len(kwargs) > expected:
                raise TypeError(
                    f"{method_name}() got an unexpected keyword argument {next(iter(kwargs))}"
                )

            if len(args) == 1 and not kwargs and isinstance(args[0], list):
                batch_inputs = args[0]
                # a list of plain dicts that is not an OpenAI chat payload is a batch call
                every_item_is_dict = all(isinstance(item, dict) for item in batch_inputs)
                if every_item_is_dict and not is_openai_chat_format(batch_inputs):
                    return call_func(batch_inputs, method_name)

            # fold positionals into kwargs; zip stops at the shorter sequence
            for name, value in zip(method_argnames, args):
                if name in kwargs:
                    raise TypeError(f"Multiple values for argument {name}")
                kwargs[name] = value
            return call_func(kwargs, method_name)

        return f

    for method_name, method_signature in self._method_signatures.items():
        type_name = resources_pb2.RunnerMethodType.Name(method_signature.method_type)
        caller_attr = caller_attr_by_type.get(type_name)
        if caller_attr is None:
            raise ValueError(f"Unknown method type {method_signature.method_type}")
        call_func = getattr(self, caller_attr)

        # method argnames, in order, collapsing nested keys to corresponding user function args
        method_argnames = list(
            dict.fromkeys(var.name.split('.', 1)[0] for var in method_signature.input_fields)
        )

        # bind through a factory so each f keeps its own method_name, not the loop variable
        f = bind_f(method_name, method_argnames, call_func)

        # set names, annotations and docstrings
        f.__name__ = method_name
        f.__qualname__ = f'{self.__class__.__name__}.{method_name}'
        input_annotations = code_script._get_annotations_source(method_signature)
        return_annotation = input_annotations.pop('return', (None, None))[0]
        parameters = [
            inspect.Parameter(
                arg_name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=annotation[0]
            )
            for arg_name, annotation in input_annotations.items()
        ]
        f.__signature__ = inspect.signature(f).replace(
            parameters=parameters,
            return_annotation=return_annotation,
        )
        f.__doc__ = method_signature.description
        setattr(self, method_name, f)
210
+
211
def available_methods(self) -> List[str]:
    """Get the available methods for this model.

    Calls ``self.fetch()`` first when ``self._defined`` is false.

    Returns:
        List[str]: The available method names, in signature order. This is a
        new list, not a live view of the internal signature mapping.
    """
    if not self._defined:
        self.fetch()
    # .keys() would hand back a dict view, contradicting the declared List[str]
    return list(self._method_signatures)
220
+
221
def method_signature(self, method_name: str) -> str:
    """Render the signature of one model method.

    Calls ``self.fetch()`` first when ``self._defined`` is false.

    Args:
        method_name (str): The name of the method.

    Returns:
        str: The method signature.
    """
    if not self._defined:
        self.fetch()
    signature_proto = self._method_signatures[method_name]
    return method_signatures.get_method_signature(signature_proto)
233
+
234
def generate_client_script(self) -> str:
    """Build a client script covering every method of this model.

    Calls ``self.fetch()`` first when ``self._defined`` is false.

    Returns:
        str: The client script.
    """
    if not self._defined:
        self.fetch()
    template = self.request_template
    return code_script.generate_client_script(
        list(self._method_signatures.values()),
        user_id=template.user_app_id.user_id,
        app_id=template.user_app_id.app_id,
        model_id=template.model_id,
    )
251
+
252
def _define_compatability_functions(self):
    """Install a single ``predict`` callable on this instance.

    The callable serializes its argument with ``CompatibilitySerializer``,
    moves any ``string_value`` into ``text.raw`` on the way out and back into
    ``string_value`` on the way in, and calls ``_predict_by_proto`` without a
    method name.
    """
    compat = CompatibilitySerializer()

    def predict(input: Any) -> Any:
        request_input = resources_pb2.Input()
        compat.serialize(request_input.data, input)
        # always use text.raw for compat
        outgoing = request_input.data
        if outgoing.string_value:
            outgoing.text.raw = outgoing.string_value
            outgoing.string_value = ''

        response = self._predict_by_proto([request_input])
        if response.status.code != status_code_pb2.SUCCESS:
            raise Exception(f"Model predict failed with response {response!r}")

        incoming = response.outputs[0].data
        if incoming.text.raw:
            incoming.string_value = incoming.text.raw
            incoming.text.raw = ''
        return compat.deserialize(incoming)

    self.predict = predict
272
+
273
def _predict(
    self,
    inputs,  # TODO set up functions according to fetched signatures?
    method_name: str = 'predict',
) -> Any:
    """Run a unary method for one input dict or a batch (list) of input dicts.

    Args:
        inputs: A single dict of argument values, or a list of such dicts.
        method_name (str): Key into ``self._method_signatures`` selecting the method.

    Returns:
        Any: One deserialized output for a dict input, otherwise a list of outputs.
    """
    signature = self._method_signatures[method_name]
    input_signature = signature.input_fields
    output_signature = signature.output_fields

    # a bare dict is a single call; remember that so the result can be unwrapped
    single_call = isinstance(inputs, dict)
    batch = [inputs] if single_call else inputs

    proto_inputs = []
    for item in batch:
        proto = resources_pb2.Input()
        serialize(item, input_signature, proto.data)
        proto_inputs.append(proto)

    response = self._predict_by_proto(proto_inputs, method_name)

    outputs = [
        deserialize(output.data, output_signature, is_output=True)
        for output in response.outputs
    ]
    return outputs[0] if single_call else outputs
301
+
302
def _predict_by_proto(
    self,
    inputs: List[resources_pb2.Input],
    method_name: str = None,
    inference_params: Dict = None,
    output_config: Dict = None,
) -> service_pb2.MultiOutputResponse:
    """Send already-serialized inputs through ``PostModelOutputs``.

    Args:
        inputs (List[resources_pb2.Input]): The inputs to predict.
        method_name (str): The remote method name to call.
        inference_params (Dict): Inference parameters to override.
        output_config (Dict): Output configuration to override.

    Returns:
        service_pb2.MultiOutputResponse: The prediction response(s).

    Raises:
        UserError: If ``inputs`` is not a list or exceeds MAX_MODEL_PREDICT_INPUTS.
        Exception: If the final response status is not SUCCESS.
    """
    if not isinstance(inputs, list):
        raise UserError('Invalid inputs, inputs must be a list of Input objects.')
    if len(inputs) > MAX_MODEL_PREDICT_INPUTS:
        raise UserError(f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}.")

    request = service_pb2.PostModelOutputsRequest()
    request.CopyFrom(self.request_template)
    request.inputs.extend(inputs)

    if method_name:
        # TODO put in new proto field?
        for request_input in request.inputs:
            request_input.data.metadata['_method_name'] = method_name
    if inference_params:
        request.model.model_version.output_info.params.update(inference_params)
    if output_config:
        override = resources_pb2.OutputConfig(**output_config)
        request.model.model_version.output_info.output_config.MergeFrom(override)

    start_time = time.time()
    backoff_iterator = BackoffIterator(10)
    response = self.STUB.PostModelOutputs(request)
    # keep re-sending while the status is retryable, for at most 10 minutes
    while status_is_retryable(response.status.code) and time.time() - start_time < 60 * 10:
        logger.info("Model is still deploying, please wait...")
        time.sleep(next(backoff_iterator))
        response = self.STUB.PostModelOutputs(request)

    if response.status.code != status_code_pb2.SUCCESS:
        raise Exception(f"Model predict failed with response {response!r}")
    return response
356
+
357
def _generate(
    self,
    inputs,  # TODO set up functions according to fetched signatures?
    method_name: str = 'generate',
) -> Any:
    """Run a unary-in / streaming-out method and yield deserialized outputs.

    Args:
        inputs: A single dict of argument values, or a list of such dicts.
        method_name (str): Key into ``self._method_signatures`` selecting the method.

    Yields:
        Any: Per streamed response, one output for a dict input or a list of
        outputs for a list input.
    """
    signature = self._method_signatures[method_name]
    input_signature = signature.input_fields
    output_signature = signature.output_fields

    # a bare dict is a single call; remember that so each result can be unwrapped
    single_call = isinstance(inputs, dict)
    batch = [inputs] if single_call else inputs

    proto_inputs = []
    for item in batch:
        proto = resources_pb2.Input()
        serialize(item, input_signature, proto.data)
        proto_inputs.append(proto)

    for response in self._generate_by_proto(proto_inputs, method_name):
        outputs = [
            deserialize(output.data, output_signature, is_output=True)
            for output in response.outputs
        ]
        if single_call:
            yield outputs[0]
        else:
            yield outputs
386
+
387
def _generate_by_proto(
    self,
    inputs: List[resources_pb2.Input],
    method_name: str = None,
    inference_params: Dict = None,
    output_config: Dict = None,
):
    """Generate the stream output on model based on the given inputs.

    Args:
        inputs (list[Input]): The inputs to generate; at most MAX_MODEL_PREDICT_INPUTS.
        method_name (str): The remote method name to call.
        inference_params (dict): The inference params to override.
        output_config (dict): The output config to override.

    Yields:
        The responses streamed back by ``GenerateModelOutputs``, in order.

    Raises:
        UserError: If ``inputs`` is not a list or is too long.
        Exception: If the stream is empty or any response status is not SUCCESS.
    """
    if not isinstance(inputs, list):
        raise UserError('Invalid inputs, inputs must be a list of Input objects.')
    if len(inputs) > MAX_MODEL_PREDICT_INPUTS:
        raise UserError(
            f"Too many inputs. Max is {MAX_MODEL_PREDICT_INPUTS}."
        )  # TODO Use Chunker for inputs len > 128

    request = service_pb2.PostModelOutputsRequest()
    request.CopyFrom(self.request_template)

    request.inputs.extend(inputs)

    if method_name:
        # TODO put in new proto field?
        for inp in request.inputs:
            inp.data.metadata['_method_name'] = method_name
    if inference_params:
        request.model.model_version.output_info.params.update(inference_params)
    if output_config:
        # protobuf messages have no MergeFromDict(); build the message from the
        # dict and merge it, exactly as _predict_by_proto does
        request.model.model_version.output_info.output_config.MergeFrom(
            resources_pb2.OutputConfig(**output_config)
        )

    start_time = time.time()
    backoff_iterator = BackoffIterator(10)
    started = False
    while not started:
        stream_response = self.STUB.GenerateModelOutputs(request)
        try:
            response = next(stream_response)  # get the first response
        except StopIteration:
            raise Exception("Model Generate failed with no response")
        # only the first response of each attempt is checked for a retryable status
        if status_is_retryable(response.status.code) and time.time() - start_time < 60 * 10:
            logger.info("Model is still deploying, please wait...")
            time.sleep(next(backoff_iterator))
            continue
        if response.status.code != status_code_pb2.SUCCESS:
            raise Exception(f"Model Generate failed with response {response.status!r}")
        started = True

    yield response  # yield the first response

    for response in stream_response:
        if response.status.code != status_code_pb2.SUCCESS:
            raise Exception(f"Model Generate failed with response {response.status!r}")
        yield response
446
+
447
def _stream(
    self,
    inputs,
    method_name: str = 'stream',
) -> Any:
    """Run a streaming-in / streaming-out method and yield deserialized outputs.

    Args:
        inputs: A dict of argument values, or a one-element list holding such a
            dict. The value under the signature's stream argument is consumed
            with ``next()`` and ``for``.
        method_name (str): Key into ``self._method_signatures`` selecting the method.

    Yields:
        Any: The deserialized single output of each streamed response.

    Raises:
        ValueError: If the input signature contains no stream field.
    """
    input_signature = self._method_signatures[method_name].input_fields
    output_signature = self._method_signatures[method_name].output_fields

    if isinstance(inputs, list):
        assert len(inputs) == 1, 'streaming methods do not support batched calls'
        inputs = inputs[0]
    assert isinstance(inputs, dict)
    # shallow copy: the pop and re-assignment below must not alter the caller's dict
    kwargs = dict(inputs)

    # find the streaming vars in the input signature, and the streaming input python param
    stream_sig = get_stream_from_signature(input_signature)
    if stream_sig is None:
        raise ValueError("Streaming method must have a Stream input")
    stream_argname = stream_sig.name

    # get the streaming input generator from the user-provided function arg values
    user_inputs_generator = kwargs.pop(stream_argname)

    def _input_proto_stream():
        # first item contains all the inputs and the first stream item
        proto = resources_pb2.Input()
        try:
            item = next(user_inputs_generator)
        except StopIteration:
            return  # no items to stream
        kwargs[stream_argname] = item
        serialize(kwargs, input_signature, proto.data)

        yield proto

        # subsequent items are just the stream items
        for item in user_inputs_generator:
            proto = resources_pb2.Input()
            serialize({stream_argname: item}, [stream_sig], proto.data)
            yield proto

    response_stream = self._stream_by_proto(_input_proto_stream(), method_name)

    for response in response_stream:
        assert len(response.outputs) == 1, 'streaming methods must have exactly one output'
        yield deserialize(response.outputs[0].data, output_signature, is_output=True)
493
+
494
def _req_iterator(
    self,
    input_iterator: Iterator[List[resources_pb2.Input]],
    method_name: str = None,
    inference_params: Dict = None,
    output_config: Dict = None,
):
    """Wrap each item of ``input_iterator`` in its own ``PostModelOutputsRequest``.

    Args:
        input_iterator: Yields either a list of inputs or a single input per request.
        method_name (str): If set, written to every input's ``_method_name`` metadata.
        inference_params (Dict): Inference parameters to override.
        output_config (Dict): Output configuration to override.

    Yields:
        service_pb2.PostModelOutputsRequest: One request per item of ``input_iterator``.
    """
    # base request shared by every yielded request
    request = service_pb2.PostModelOutputsRequest()
    request.CopyFrom(self.request_template)
    if inference_params:
        request.model.model_version.output_info.params.update(inference_params)
    if output_config:
        # protobuf messages have no MergeFromDict(); build the message from the
        # dict and merge it, exactly as _predict_by_proto does
        request.model.model_version.output_info.output_config.MergeFrom(
            resources_pb2.OutputConfig(**output_config)
        )
    for inputs in input_iterator:
        req = service_pb2.PostModelOutputsRequest()
        req.CopyFrom(request)
        if isinstance(inputs, list):
            req.inputs.extend(inputs)
        else:
            req.inputs.append(inputs)
        # TODO: put into new proto field?
        if method_name:
            for inp in req.inputs:
                inp.data.metadata['_method_name'] = method_name
        yield req
519
+
520
def _stream_by_proto(
    self,
    inputs: Iterator[List[resources_pb2.Input]],
    method_name: str = None,
    inference_params: Dict = None,
    output_config: Dict = None,
):
    """Generate the stream output on model based on the given stream of inputs.

    Args:
        inputs: Iterator of inputs, turned into requests by ``_req_iterator``.
        method_name (str): The remote method name to call.
        inference_params (Dict): Inference parameters to override.
        output_config (Dict): Output configuration to override.

    Yields:
        The SUCCESS responses streamed back by ``StreamModelOutputs``.

    Raises:
        Exception: If a response is neither SUCCESS nor retried.
    """
    # if not isinstance(inputs, Iterator[List[Input]]):
    #     raise UserError('Invalid inputs, inputs must be a iterator of list of Input objects.')

    request = self._req_iterator(inputs, method_name, inference_params, output_config)

    start_time = time.time()
    backoff_iterator = BackoffIterator(10)
    generation_started = False
    while not generation_started:
        retry_requested = False
        stream_response = self.STUB.StreamModelOutputs(request)
        for response in stream_response:
            if (
                status_is_retryable(response.status.code)
                and time.time() - start_time < 60 * 10
            ):
                logger.info("Model is still deploying, please wait...")
                time.sleep(next(backoff_iterator))
                retry_requested = True
                break
            if response.status.code != status_code_pb2.SUCCESS:
                raise Exception(f"Model Predict failed with response {response.status!r}")
            generation_started = True
            yield response
        if not retry_requested:
            # The stream ended on its own. Without this exit, a stream that
            # produced no responses would re-issue the RPC forever.
            return