clarifai 11.3.0rc2__py3-none-any.whl → 11.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (300) hide show
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/__main__.py +1 -1
  3. clarifai/cli/base.py +144 -136
  4. clarifai/cli/compute_cluster.py +45 -31
  5. clarifai/cli/deployment.py +93 -76
  6. clarifai/cli/model.py +578 -180
  7. clarifai/cli/nodepool.py +100 -82
  8. clarifai/client/__init__.py +12 -2
  9. clarifai/client/app.py +973 -911
  10. clarifai/client/auth/helper.py +345 -342
  11. clarifai/client/auth/register.py +7 -7
  12. clarifai/client/auth/stub.py +107 -106
  13. clarifai/client/base.py +185 -178
  14. clarifai/client/compute_cluster.py +214 -180
  15. clarifai/client/dataset.py +793 -698
  16. clarifai/client/deployment.py +55 -50
  17. clarifai/client/input.py +1223 -1088
  18. clarifai/client/lister.py +47 -45
  19. clarifai/client/model.py +1939 -1717
  20. clarifai/client/model_client.py +525 -502
  21. clarifai/client/module.py +82 -73
  22. clarifai/client/nodepool.py +358 -213
  23. clarifai/client/runner.py +58 -0
  24. clarifai/client/search.py +342 -309
  25. clarifai/client/user.py +419 -414
  26. clarifai/client/workflow.py +294 -274
  27. clarifai/constants/dataset.py +11 -17
  28. clarifai/constants/model.py +8 -2
  29. clarifai/datasets/export/inputs_annotations.py +233 -217
  30. clarifai/datasets/upload/base.py +63 -51
  31. clarifai/datasets/upload/features.py +43 -38
  32. clarifai/datasets/upload/image.py +237 -207
  33. clarifai/datasets/upload/loaders/coco_captions.py +34 -32
  34. clarifai/datasets/upload/loaders/coco_detection.py +72 -65
  35. clarifai/datasets/upload/loaders/imagenet_classification.py +57 -53
  36. clarifai/datasets/upload/loaders/xview_detection.py +274 -132
  37. clarifai/datasets/upload/multimodal.py +55 -46
  38. clarifai/datasets/upload/text.py +55 -47
  39. clarifai/datasets/upload/utils.py +250 -234
  40. clarifai/errors.py +51 -50
  41. clarifai/models/api.py +260 -238
  42. clarifai/modules/css.py +50 -50
  43. clarifai/modules/pages.py +33 -33
  44. clarifai/rag/rag.py +312 -288
  45. clarifai/rag/utils.py +91 -84
  46. clarifai/runners/models/model_builder.py +906 -802
  47. clarifai/runners/models/model_class.py +370 -331
  48. clarifai/runners/models/model_run_locally.py +459 -419
  49. clarifai/runners/models/model_runner.py +170 -162
  50. clarifai/runners/models/model_servicer.py +78 -70
  51. clarifai/runners/server.py +111 -101
  52. clarifai/runners/utils/code_script.py +225 -187
  53. clarifai/runners/utils/const.py +4 -1
  54. clarifai/runners/utils/data_types/__init__.py +12 -0
  55. clarifai/runners/utils/data_types/data_types.py +598 -0
  56. clarifai/runners/utils/data_utils.py +387 -440
  57. clarifai/runners/utils/loader.py +247 -227
  58. clarifai/runners/utils/method_signatures.py +411 -386
  59. clarifai/runners/utils/openai_convertor.py +108 -109
  60. clarifai/runners/utils/serializers.py +175 -179
  61. clarifai/runners/utils/url_fetcher.py +35 -35
  62. clarifai/schema/search.py +56 -63
  63. clarifai/urls/helper.py +125 -102
  64. clarifai/utils/cli.py +129 -123
  65. clarifai/utils/config.py +127 -87
  66. clarifai/utils/constants.py +49 -0
  67. clarifai/utils/evaluation/helpers.py +503 -466
  68. clarifai/utils/evaluation/main.py +431 -393
  69. clarifai/utils/evaluation/testset_annotation_parser.py +154 -144
  70. clarifai/utils/logging.py +324 -306
  71. clarifai/utils/misc.py +60 -56
  72. clarifai/utils/model_train.py +165 -146
  73. clarifai/utils/protobuf.py +126 -103
  74. clarifai/versions.py +3 -1
  75. clarifai/workflows/export.py +48 -50
  76. clarifai/workflows/utils.py +39 -36
  77. clarifai/workflows/validate.py +55 -43
  78. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/METADATA +16 -6
  79. clarifai-11.4.0.dist-info/RECORD +109 -0
  80. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/WHEEL +1 -1
  81. clarifai/__pycache__/__init__.cpython-310.pyc +0 -0
  82. clarifai/__pycache__/__init__.cpython-311.pyc +0 -0
  83. clarifai/__pycache__/__init__.cpython-39.pyc +0 -0
  84. clarifai/__pycache__/errors.cpython-310.pyc +0 -0
  85. clarifai/__pycache__/errors.cpython-311.pyc +0 -0
  86. clarifai/__pycache__/versions.cpython-310.pyc +0 -0
  87. clarifai/__pycache__/versions.cpython-311.pyc +0 -0
  88. clarifai/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  89. clarifai/cli/__pycache__/__init__.cpython-311.pyc +0 -0
  90. clarifai/cli/__pycache__/base.cpython-310.pyc +0 -0
  91. clarifai/cli/__pycache__/base.cpython-311.pyc +0 -0
  92. clarifai/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  93. clarifai/cli/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  94. clarifai/cli/__pycache__/compute_cluster.cpython-311.pyc +0 -0
  95. clarifai/cli/__pycache__/deployment.cpython-310.pyc +0 -0
  96. clarifai/cli/__pycache__/deployment.cpython-311.pyc +0 -0
  97. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  98. clarifai/cli/__pycache__/model.cpython-311.pyc +0 -0
  99. clarifai/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  100. clarifai/cli/__pycache__/nodepool.cpython-310.pyc +0 -0
  101. clarifai/cli/__pycache__/nodepool.cpython-311.pyc +0 -0
  102. clarifai/client/__pycache__/__init__.cpython-310.pyc +0 -0
  103. clarifai/client/__pycache__/__init__.cpython-311.pyc +0 -0
  104. clarifai/client/__pycache__/__init__.cpython-39.pyc +0 -0
  105. clarifai/client/__pycache__/app.cpython-310.pyc +0 -0
  106. clarifai/client/__pycache__/app.cpython-311.pyc +0 -0
  107. clarifai/client/__pycache__/app.cpython-39.pyc +0 -0
  108. clarifai/client/__pycache__/base.cpython-310.pyc +0 -0
  109. clarifai/client/__pycache__/base.cpython-311.pyc +0 -0
  110. clarifai/client/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  111. clarifai/client/__pycache__/compute_cluster.cpython-311.pyc +0 -0
  112. clarifai/client/__pycache__/dataset.cpython-310.pyc +0 -0
  113. clarifai/client/__pycache__/dataset.cpython-311.pyc +0 -0
  114. clarifai/client/__pycache__/deployment.cpython-310.pyc +0 -0
  115. clarifai/client/__pycache__/deployment.cpython-311.pyc +0 -0
  116. clarifai/client/__pycache__/input.cpython-310.pyc +0 -0
  117. clarifai/client/__pycache__/input.cpython-311.pyc +0 -0
  118. clarifai/client/__pycache__/lister.cpython-310.pyc +0 -0
  119. clarifai/client/__pycache__/lister.cpython-311.pyc +0 -0
  120. clarifai/client/__pycache__/model.cpython-310.pyc +0 -0
  121. clarifai/client/__pycache__/model.cpython-311.pyc +0 -0
  122. clarifai/client/__pycache__/module.cpython-310.pyc +0 -0
  123. clarifai/client/__pycache__/module.cpython-311.pyc +0 -0
  124. clarifai/client/__pycache__/nodepool.cpython-310.pyc +0 -0
  125. clarifai/client/__pycache__/nodepool.cpython-311.pyc +0 -0
  126. clarifai/client/__pycache__/search.cpython-310.pyc +0 -0
  127. clarifai/client/__pycache__/search.cpython-311.pyc +0 -0
  128. clarifai/client/__pycache__/user.cpython-310.pyc +0 -0
  129. clarifai/client/__pycache__/user.cpython-311.pyc +0 -0
  130. clarifai/client/__pycache__/workflow.cpython-310.pyc +0 -0
  131. clarifai/client/__pycache__/workflow.cpython-311.pyc +0 -0
  132. clarifai/client/auth/__pycache__/__init__.cpython-310.pyc +0 -0
  133. clarifai/client/auth/__pycache__/__init__.cpython-311.pyc +0 -0
  134. clarifai/client/auth/__pycache__/helper.cpython-310.pyc +0 -0
  135. clarifai/client/auth/__pycache__/helper.cpython-311.pyc +0 -0
  136. clarifai/client/auth/__pycache__/register.cpython-310.pyc +0 -0
  137. clarifai/client/auth/__pycache__/register.cpython-311.pyc +0 -0
  138. clarifai/client/auth/__pycache__/stub.cpython-310.pyc +0 -0
  139. clarifai/client/auth/__pycache__/stub.cpython-311.pyc +0 -0
  140. clarifai/client/cli/__init__.py +0 -0
  141. clarifai/client/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  142. clarifai/client/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  143. clarifai/client/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  144. clarifai/client/cli/base_cli.py +0 -88
  145. clarifai/client/cli/model_cli.py +0 -29
  146. clarifai/constants/__pycache__/base.cpython-310.pyc +0 -0
  147. clarifai/constants/__pycache__/base.cpython-311.pyc +0 -0
  148. clarifai/constants/__pycache__/dataset.cpython-310.pyc +0 -0
  149. clarifai/constants/__pycache__/dataset.cpython-311.pyc +0 -0
  150. clarifai/constants/__pycache__/input.cpython-310.pyc +0 -0
  151. clarifai/constants/__pycache__/input.cpython-311.pyc +0 -0
  152. clarifai/constants/__pycache__/model.cpython-310.pyc +0 -0
  153. clarifai/constants/__pycache__/model.cpython-311.pyc +0 -0
  154. clarifai/constants/__pycache__/rag.cpython-310.pyc +0 -0
  155. clarifai/constants/__pycache__/rag.cpython-311.pyc +0 -0
  156. clarifai/constants/__pycache__/search.cpython-310.pyc +0 -0
  157. clarifai/constants/__pycache__/search.cpython-311.pyc +0 -0
  158. clarifai/constants/__pycache__/workflow.cpython-310.pyc +0 -0
  159. clarifai/constants/__pycache__/workflow.cpython-311.pyc +0 -0
  160. clarifai/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  161. clarifai/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
  162. clarifai/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  163. clarifai/datasets/export/__pycache__/__init__.cpython-310.pyc +0 -0
  164. clarifai/datasets/export/__pycache__/__init__.cpython-311.pyc +0 -0
  165. clarifai/datasets/export/__pycache__/__init__.cpython-39.pyc +0 -0
  166. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-310.pyc +0 -0
  167. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-311.pyc +0 -0
  168. clarifai/datasets/upload/__pycache__/__init__.cpython-310.pyc +0 -0
  169. clarifai/datasets/upload/__pycache__/__init__.cpython-311.pyc +0 -0
  170. clarifai/datasets/upload/__pycache__/__init__.cpython-39.pyc +0 -0
  171. clarifai/datasets/upload/__pycache__/base.cpython-310.pyc +0 -0
  172. clarifai/datasets/upload/__pycache__/base.cpython-311.pyc +0 -0
  173. clarifai/datasets/upload/__pycache__/features.cpython-310.pyc +0 -0
  174. clarifai/datasets/upload/__pycache__/features.cpython-311.pyc +0 -0
  175. clarifai/datasets/upload/__pycache__/image.cpython-310.pyc +0 -0
  176. clarifai/datasets/upload/__pycache__/image.cpython-311.pyc +0 -0
  177. clarifai/datasets/upload/__pycache__/multimodal.cpython-310.pyc +0 -0
  178. clarifai/datasets/upload/__pycache__/multimodal.cpython-311.pyc +0 -0
  179. clarifai/datasets/upload/__pycache__/text.cpython-310.pyc +0 -0
  180. clarifai/datasets/upload/__pycache__/text.cpython-311.pyc +0 -0
  181. clarifai/datasets/upload/__pycache__/utils.cpython-310.pyc +0 -0
  182. clarifai/datasets/upload/__pycache__/utils.cpython-311.pyc +0 -0
  183. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-311.pyc +0 -0
  184. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-39.pyc +0 -0
  185. clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-311.pyc +0 -0
  186. clarifai/datasets/upload/loaders/__pycache__/imagenet_classification.cpython-311.pyc +0 -0
  187. clarifai/models/__pycache__/__init__.cpython-39.pyc +0 -0
  188. clarifai/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  189. clarifai/rag/__pycache__/__init__.cpython-310.pyc +0 -0
  190. clarifai/rag/__pycache__/__init__.cpython-311.pyc +0 -0
  191. clarifai/rag/__pycache__/__init__.cpython-39.pyc +0 -0
  192. clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
  193. clarifai/rag/__pycache__/rag.cpython-311.pyc +0 -0
  194. clarifai/rag/__pycache__/rag.cpython-39.pyc +0 -0
  195. clarifai/rag/__pycache__/utils.cpython-310.pyc +0 -0
  196. clarifai/rag/__pycache__/utils.cpython-311.pyc +0 -0
  197. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  198. clarifai/runners/__pycache__/__init__.cpython-311.pyc +0 -0
  199. clarifai/runners/__pycache__/__init__.cpython-39.pyc +0 -0
  200. clarifai/runners/dockerfile_template/Dockerfile.cpu.template +0 -31
  201. clarifai/runners/dockerfile_template/Dockerfile.cuda.template +0 -42
  202. clarifai/runners/dockerfile_template/Dockerfile.nim +0 -71
  203. clarifai/runners/models/__pycache__/__init__.cpython-310.pyc +0 -0
  204. clarifai/runners/models/__pycache__/__init__.cpython-311.pyc +0 -0
  205. clarifai/runners/models/__pycache__/__init__.cpython-39.pyc +0 -0
  206. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  207. clarifai/runners/models/__pycache__/base_typed_model.cpython-311.pyc +0 -0
  208. clarifai/runners/models/__pycache__/base_typed_model.cpython-39.pyc +0 -0
  209. clarifai/runners/models/__pycache__/model_builder.cpython-311.pyc +0 -0
  210. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  211. clarifai/runners/models/__pycache__/model_class.cpython-311.pyc +0 -0
  212. clarifai/runners/models/__pycache__/model_run_locally.cpython-310-pytest-7.1.2.pyc +0 -0
  213. clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
  214. clarifai/runners/models/__pycache__/model_run_locally.cpython-311.pyc +0 -0
  215. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  216. clarifai/runners/models/__pycache__/model_runner.cpython-311.pyc +0 -0
  217. clarifai/runners/models/__pycache__/model_upload.cpython-310.pyc +0 -0
  218. clarifai/runners/models/base_typed_model.py +0 -238
  219. clarifai/runners/models/model_class_refract.py +0 -80
  220. clarifai/runners/models/model_upload.py +0 -607
  221. clarifai/runners/models/temp.py +0 -25
  222. clarifai/runners/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  223. clarifai/runners/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  224. clarifai/runners/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  225. clarifai/runners/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  226. clarifai/runners/utils/__pycache__/buffered_stream.cpython-310.pyc +0 -0
  227. clarifai/runners/utils/__pycache__/buffered_stream.cpython-38.pyc +0 -0
  228. clarifai/runners/utils/__pycache__/buffered_stream.cpython-39.pyc +0 -0
  229. clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
  230. clarifai/runners/utils/__pycache__/const.cpython-311.pyc +0 -0
  231. clarifai/runners/utils/__pycache__/constants.cpython-310.pyc +0 -0
  232. clarifai/runners/utils/__pycache__/constants.cpython-38.pyc +0 -0
  233. clarifai/runners/utils/__pycache__/constants.cpython-39.pyc +0 -0
  234. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  235. clarifai/runners/utils/__pycache__/data_handler.cpython-311.pyc +0 -0
  236. clarifai/runners/utils/__pycache__/data_handler.cpython-38.pyc +0 -0
  237. clarifai/runners/utils/__pycache__/data_handler.cpython-39.pyc +0 -0
  238. clarifai/runners/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
  239. clarifai/runners/utils/__pycache__/data_utils.cpython-311.pyc +0 -0
  240. clarifai/runners/utils/__pycache__/data_utils.cpython-38.pyc +0 -0
  241. clarifai/runners/utils/__pycache__/data_utils.cpython-39.pyc +0 -0
  242. clarifai/runners/utils/__pycache__/grpc_server.cpython-310.pyc +0 -0
  243. clarifai/runners/utils/__pycache__/grpc_server.cpython-38.pyc +0 -0
  244. clarifai/runners/utils/__pycache__/grpc_server.cpython-39.pyc +0 -0
  245. clarifai/runners/utils/__pycache__/health.cpython-310.pyc +0 -0
  246. clarifai/runners/utils/__pycache__/health.cpython-38.pyc +0 -0
  247. clarifai/runners/utils/__pycache__/health.cpython-39.pyc +0 -0
  248. clarifai/runners/utils/__pycache__/loader.cpython-310.pyc +0 -0
  249. clarifai/runners/utils/__pycache__/loader.cpython-311.pyc +0 -0
  250. clarifai/runners/utils/__pycache__/logging.cpython-310.pyc +0 -0
  251. clarifai/runners/utils/__pycache__/logging.cpython-38.pyc +0 -0
  252. clarifai/runners/utils/__pycache__/logging.cpython-39.pyc +0 -0
  253. clarifai/runners/utils/__pycache__/stream_source.cpython-310.pyc +0 -0
  254. clarifai/runners/utils/__pycache__/stream_source.cpython-39.pyc +0 -0
  255. clarifai/runners/utils/__pycache__/url_fetcher.cpython-310.pyc +0 -0
  256. clarifai/runners/utils/__pycache__/url_fetcher.cpython-311.pyc +0 -0
  257. clarifai/runners/utils/__pycache__/url_fetcher.cpython-38.pyc +0 -0
  258. clarifai/runners/utils/__pycache__/url_fetcher.cpython-39.pyc +0 -0
  259. clarifai/runners/utils/data_handler.py +0 -231
  260. clarifai/runners/utils/data_handler_refract.py +0 -213
  261. clarifai/runners/utils/data_types.py +0 -469
  262. clarifai/runners/utils/logger.py +0 -0
  263. clarifai/runners/utils/openai_format.py +0 -87
  264. clarifai/schema/__pycache__/search.cpython-310.pyc +0 -0
  265. clarifai/schema/__pycache__/search.cpython-311.pyc +0 -0
  266. clarifai/urls/__pycache__/helper.cpython-310.pyc +0 -0
  267. clarifai/urls/__pycache__/helper.cpython-311.pyc +0 -0
  268. clarifai/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  269. clarifai/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  270. clarifai/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  271. clarifai/utils/__pycache__/cli.cpython-310.pyc +0 -0
  272. clarifai/utils/__pycache__/cli.cpython-311.pyc +0 -0
  273. clarifai/utils/__pycache__/config.cpython-311.pyc +0 -0
  274. clarifai/utils/__pycache__/constants.cpython-310.pyc +0 -0
  275. clarifai/utils/__pycache__/constants.cpython-311.pyc +0 -0
  276. clarifai/utils/__pycache__/logging.cpython-310.pyc +0 -0
  277. clarifai/utils/__pycache__/logging.cpython-311.pyc +0 -0
  278. clarifai/utils/__pycache__/misc.cpython-310.pyc +0 -0
  279. clarifai/utils/__pycache__/misc.cpython-311.pyc +0 -0
  280. clarifai/utils/__pycache__/model_train.cpython-310.pyc +0 -0
  281. clarifai/utils/__pycache__/model_train.cpython-311.pyc +0 -0
  282. clarifai/utils/__pycache__/protobuf.cpython-311.pyc +0 -0
  283. clarifai/utils/evaluation/__pycache__/__init__.cpython-311.pyc +0 -0
  284. clarifai/utils/evaluation/__pycache__/__init__.cpython-39.pyc +0 -0
  285. clarifai/utils/evaluation/__pycache__/helpers.cpython-311.pyc +0 -0
  286. clarifai/utils/evaluation/__pycache__/main.cpython-311.pyc +0 -0
  287. clarifai/utils/evaluation/__pycache__/main.cpython-39.pyc +0 -0
  288. clarifai/workflows/__pycache__/__init__.cpython-310.pyc +0 -0
  289. clarifai/workflows/__pycache__/__init__.cpython-311.pyc +0 -0
  290. clarifai/workflows/__pycache__/__init__.cpython-39.pyc +0 -0
  291. clarifai/workflows/__pycache__/export.cpython-310.pyc +0 -0
  292. clarifai/workflows/__pycache__/export.cpython-311.pyc +0 -0
  293. clarifai/workflows/__pycache__/utils.cpython-310.pyc +0 -0
  294. clarifai/workflows/__pycache__/utils.cpython-311.pyc +0 -0
  295. clarifai/workflows/__pycache__/validate.cpython-310.pyc +0 -0
  296. clarifai/workflows/__pycache__/validate.cpython-311.pyc +0 -0
  297. clarifai-11.3.0rc2.dist-info/RECORD +0 -322
  298. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/entry_points.txt +0 -0
  299. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info/licenses}/LICENSE +0 -0
  300. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,10 @@
1
- import collections.abc as abc
2
1
  import inspect
3
2
  import itertools
4
3
  import logging
5
4
  import os
6
5
  import traceback
7
6
  from abc import ABC
7
+ from collections import abc
8
8
  from typing import Any, Dict, Iterator, List
9
9
 
10
10
  from clarifai_grpc.grpc.api import resources_pb2, service_pb2
@@ -13,9 +13,13 @@ from google.protobuf import json_format
13
13
 
14
14
  from clarifai.runners.utils import data_types
15
15
  from clarifai.runners.utils.data_utils import DataConverter
16
- from clarifai.runners.utils.method_signatures import (build_function_signature, deserialize,
17
- get_stream_from_signature, serialize,
18
- signatures_to_json)
16
+ from clarifai.runners.utils.method_signatures import (
17
+ build_function_signature,
18
+ deserialize,
19
+ get_stream_from_signature,
20
+ serialize,
21
+ signatures_to_json,
22
+ )
19
23
 
20
24
  _METHOD_INFO_ATTR = '_cf_method_info'
21
25
 
@@ -23,337 +27,372 @@ _RAISE_EXCEPTIONS = os.getenv("RAISE_EXCEPTIONS", "false").lower() in ("true", "
23
27
 
24
28
 
25
29
  class ModelClass(ABC):
26
- '''
27
- Base class for model classes that can be run as a service.
28
-
29
- Define predict, generate, or stream methods using the @ModelClass.method decorator.
30
-
31
- Example:
32
-
33
- from clarifai.runners.model_class import ModelClass
34
- from clarifai.runners.utils.data_types import NamedFields
35
- from typing import List, Iterator
36
-
37
- class MyModel(ModelClass):
38
-
39
- @ModelClass.method
40
- def predict(self, x: str, y: int) -> List[str]:
41
- return [x] * y
42
-
43
- @ModelClass.method
44
- def generate(self, x: str, y: int) -> Iterator[str]:
45
- for i in range(y):
46
- yield x + str(i)
47
-
48
- @ModelClass.method
49
- def stream(self, input_stream: Iterator[NamedFields(x=str, y=int)]) -> Iterator[str]:
50
- for item in input_stream:
51
- yield item.x + ' ' + str(item.y)
52
- '''
53
-
54
- @staticmethod
55
- def method(func):
56
- setattr(func, _METHOD_INFO_ATTR, _MethodInfo(func))
57
- return func
58
-
59
- def set_output_context(self, prompt_tokens=None, completion_tokens=None):
60
- """This is used to set the prompt and completion tokens in the Output proto"""
61
- self._prompt_tokens = prompt_tokens
62
- self._completion_tokens = completion_tokens
63
-
64
- def load_model(self):
65
- """Load the model."""
66
-
67
- def _handle_get_signatures_request(self) -> service_pb2.MultiOutputResponse:
68
- methods = self._get_method_info()
69
- signatures = {method.name: method.signature for method in methods.values()}
70
- resp = service_pb2.MultiOutputResponse(status=status_pb2.Status(code=status_code_pb2.SUCCESS))
71
- output = resp.outputs.add()
72
- output.status.code = status_code_pb2.SUCCESS
73
- output.data.text.raw = signatures_to_json(signatures)
74
- return resp
75
-
76
- def _batch_predict(self, method, inputs: List[Dict[str, Any]]) -> List[Any]:
77
- """Batch predict method for multiple inputs."""
78
- outputs = []
79
- for input in inputs:
80
- output = method(**input)
81
- outputs.append(output)
82
- return outputs
83
-
84
- def _batch_generate(self, method, inputs: List[Dict[str, Any]]) -> Iterator[List[Any]]:
85
- """Batch generate method for multiple inputs."""
86
- generators = [method(**input) for input in inputs]
87
- for outputs in itertools.zip_longest(*generators):
88
- yield outputs
89
-
90
- def predict_wrapper(
91
- self, request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
92
- outputs = []
93
- try:
94
- # TODO add method name field to proto
95
- method_name = 'predict'
96
- inference_params = get_inference_params(request)
97
- if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
98
- method_name = request.inputs[0].data.metadata['_method_name']
99
- if method_name == '_GET_SIGNATURES': # special case to fetch signatures, TODO add endpoint for this
100
- return self._handle_get_signatures_request()
101
- if method_name not in self._get_method_info():
102
- raise ValueError(f"Method {method_name} not found in model class")
103
- method = getattr(self, method_name)
104
- method_info = method._cf_method_info
105
- signature = method_info.signature
106
- python_param_types = method_info.python_param_types
107
- for input in request.inputs:
108
- # check if input is in old format
109
- is_convert = DataConverter.is_old_format(input.data)
110
- if is_convert:
111
- # convert to new format
112
- new_data = DataConverter.convert_input_data_to_new_format(input.data,
113
- signature.input_fields)
114
- input.data.CopyFrom(new_data)
115
- # convert inputs to python types
116
- inputs = self._convert_input_protos_to_python(request.inputs, inference_params,
117
- signature.input_fields, python_param_types)
118
- if len(inputs) == 1:
119
- inputs = inputs[0]
120
- output = method(**inputs)
121
- outputs.append(
122
- self._convert_output_to_proto(
123
- output, signature.output_fields, convert_old_format=is_convert))
124
- else:
125
- outputs = self._batch_predict(method, inputs)
126
- outputs = [
127
- self._convert_output_to_proto(
128
- output, signature.output_fields, convert_old_format=is_convert)
129
- for output in outputs
130
- ]
131
-
132
- return service_pb2.MultiOutputResponse(
133
- outputs=outputs, status=status_pb2.Status(code=status_code_pb2.SUCCESS))
134
- except Exception as e:
135
- if _RAISE_EXCEPTIONS:
136
- raise
137
- logging.exception("Error in predict")
138
- return service_pb2.MultiOutputResponse(status=status_pb2.Status(
139
- code=status_code_pb2.FAILURE,
140
- details=str(e),
141
- stack_trace=traceback.format_exc().split('\n')))
142
-
143
- def generate_wrapper(self, request: service_pb2.PostModelOutputsRequest
144
- ) -> Iterator[service_pb2.MultiOutputResponse]:
145
- try:
146
- method_name = 'generate'
147
- inference_params = get_inference_params(request)
148
- if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
149
- method_name = request.inputs[0].data.metadata['_method_name']
150
- method = getattr(self, method_name)
151
- method_info = method._cf_method_info
152
- signature = method_info.signature
153
- python_param_types = method_info.python_param_types
154
- for input in request.inputs:
155
- # check if input is in old format
156
- is_convert = DataConverter.is_old_format(input.data)
157
- if is_convert:
158
- # convert to new format
159
- new_data = DataConverter.convert_input_data_to_new_format(input.data,
160
- signature.input_fields)
161
- input.data.CopyFrom(new_data)
162
- inputs = self._convert_input_protos_to_python(request.inputs, inference_params,
163
- signature.input_fields, python_param_types)
164
- if len(inputs) == 1:
165
- inputs = inputs[0]
166
- for output in method(**inputs):
167
- resp = service_pb2.MultiOutputResponse()
168
- self._convert_output_to_proto(
169
- output,
170
- signature.output_fields,
171
- proto=resp.outputs.add(),
172
- convert_old_format=is_convert)
173
- resp.status.code = status_code_pb2.SUCCESS
174
- yield resp
175
- else:
176
- for outputs in self._batch_generate(method, inputs):
177
- resp = service_pb2.MultiOutputResponse()
178
- for output in outputs:
179
- self._convert_output_to_proto(
180
- output,
181
- signature.output_fields,
182
- proto=resp.outputs.add(),
183
- convert_old_format=is_convert)
184
- resp.status.code = status_code_pb2.SUCCESS
185
- yield resp
186
- except Exception as e:
187
- if _RAISE_EXCEPTIONS:
188
- raise
189
- logging.exception("Error in generate")
190
- yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
191
- code=status_code_pb2.FAILURE,
192
- details=str(e),
193
- stack_trace=traceback.format_exc().split('\n')))
194
-
195
- def stream_wrapper(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
196
- ) -> Iterator[service_pb2.MultiOutputResponse]:
197
- try:
198
- request = next(request_iterator) # get first request to determine method
199
- assert len(request.inputs) == 1, "Streaming requires exactly one input"
200
-
201
- method_name = 'stream'
202
- inference_params = get_inference_params(request)
203
- if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
204
- method_name = request.inputs[0].data.metadata['_method_name']
205
- method = getattr(self, method_name)
206
- method_info = method._cf_method_info
207
- signature = method_info.signature
208
- python_param_types = method_info.python_param_types
209
-
210
- # find the streaming vars in the signature
211
- stream_sig = get_stream_from_signature(signature.input_fields)
212
- if stream_sig is None:
213
- raise ValueError("Streaming method must have a Stream input")
214
- stream_argname = stream_sig.name
215
-
216
- for input in request.inputs:
217
- # check if input is in old format
218
- is_convert = DataConverter.is_old_format(input.data)
219
- if is_convert:
220
- # convert to new format
221
- new_data = DataConverter.convert_input_data_to_new_format(input.data,
222
- signature.input_fields)
223
- input.data.CopyFrom(new_data)
224
- # convert all inputs for the first request, including the first stream value
225
- inputs = self._convert_input_protos_to_python(request.inputs, inference_params,
226
- signature.input_fields, python_param_types)
227
- kwargs = inputs[0]
228
-
229
- # first streaming item
230
- first_item = kwargs.pop(stream_argname)
231
-
232
- # streaming generator
233
- def InputStream():
234
- yield first_item
235
- # subsequent streaming items contain only the streaming input
236
- for request in request_iterator:
237
- item = self._convert_input_protos_to_python(request.inputs, inference_params,
238
- [stream_sig], python_param_types)
239
- item = item[0][stream_argname]
240
- yield item
241
-
242
- # add stream generator back to the input kwargs
243
- kwargs[stream_argname] = InputStream()
244
-
245
- for output in method(**kwargs):
246
- resp = service_pb2.MultiOutputResponse()
247
- self._convert_output_to_proto(
248
- output,
249
- signature.output_fields,
250
- proto=resp.outputs.add(),
251
- convert_old_format=is_convert)
252
- resp.status.code = status_code_pb2.SUCCESS
253
- yield resp
254
- except Exception as e:
255
- if _RAISE_EXCEPTIONS:
256
- raise
257
- logging.exception("Error in stream")
258
- yield service_pb2.MultiOutputResponse(status=status_pb2.Status(
259
- code=status_code_pb2.FAILURE,
260
- details=str(e),
261
- stack_trace=traceback.format_exc().split('\n')))
262
-
263
- def _convert_input_protos_to_python(self, inputs: List[resources_pb2.Input],
264
- inference_params: dict,
265
- variables_signature: List[resources_pb2.ModelTypeField],
266
- python_param_types) -> List[Dict[str, Any]]:
267
- result = []
268
- for input in inputs:
269
- kwargs = deserialize(input.data, variables_signature, inference_params)
270
- # dynamic cast to annotated types
271
- for k, v in kwargs.items():
272
- if k not in python_param_types:
273
- continue
274
-
275
- if hasattr(python_param_types[k],
276
- "__args__") and (getattr(python_param_types[k], "__origin__",
277
- None) in [abc.Iterator, abc.Generator, abc.Iterable]):
278
- # get the type of the items in the stream
279
- stream_type = python_param_types[k].__args__[0]
280
-
281
- kwargs[k] = data_types.cast(v, stream_type)
282
- else:
283
- kwargs[k] = data_types.cast(v, python_param_types[k])
284
- result.append(kwargs)
285
- return result
286
-
287
- def _convert_output_to_proto(self,
288
- output: Any,
289
- variables_signature: List[resources_pb2.ModelTypeField],
290
- proto=None,
291
- convert_old_format=False) -> resources_pb2.Output:
292
- if proto is None:
293
- proto = resources_pb2.Output()
294
- serialize({'return': output}, variables_signature, proto.data, is_output=True)
295
- if convert_old_format:
296
- # convert to old format
297
- data = DataConverter.convert_output_data_to_old_format(proto.data)
298
- proto.data.CopyFrom(data)
299
- proto.status.code = status_code_pb2.SUCCESS
300
- if hasattr(self, "_prompt_tokens") and self._prompt_tokens is not None:
301
- proto.prompt_tokens = self._prompt_tokens
302
- if hasattr(self, "_completion_tokens") and self._completion_tokens is not None:
303
- proto.completion_tokens = self._completion_tokens
304
- self._prompt_tokens = None
305
- self._completion_tokens = None
306
- return proto
307
-
308
@classmethod
def _register_model_methods(cls):
    """Collect every decorated model method visible on ``cls``.

    Walks the MRO from the most generic base to ``cls`` itself and records each
    class-namespace member that carries a truthy ``_METHOD_INFO_ATTR`` attribute.
    Because later (more derived) classes are visited last, a decorated method in a
    subclass replaces the entry a parent registered under the same name.

    Returns:
        dict mapping attribute name to the method-info object attached by the decorator.
    """
    registry = {}
    for klass in cls.__mro__[::-1]:
        for attr_name, member in vars(klass).items():
            info = getattr(member, _METHOD_INFO_ATTR, None)
            if info:
                registry[attr_name] = info
    return registry
326
-
327
@classmethod
def _get_method_info(cls, func_name=None):
    """Return the model-method registry of ``cls``, building and caching it on first use.

    The registry is cached as a class attribute named ``_METHOD_INFO_ATTR``. Only
    ``cls``'s own namespace is consulted for the cache, so a subclass never reuses a
    registry that was built for one of its parents.

    Args:
        func_name: Optional method name; when given, only that method's info is returned.

    Returns:
        The info object for ``func_name``, or the whole ``{name: info}`` dict.

    Raises:
        KeyError: if ``func_name`` is given but is not a registered model method.
    """
    # hasattr() would also find a registry cached on a base class, which lacks the
    # methods this subclass adds; look in this class's own __dict__ instead.
    if _METHOD_INFO_ATTR not in cls.__dict__:
        setattr(cls, _METHOD_INFO_ATTR, cls._register_model_methods())
    method_info = getattr(cls, _METHOD_INFO_ATTR)
    if func_name:
        return method_info[func_name]
    return method_info
30
+ '''
31
+ Base class for model classes that can be run as a service.
32
+
33
+ Define predict, generate, or stream methods using the @ModelClass.method decorator.
34
+
35
+ Example:
36
+
37
+ from clarifai.runners.model_class import ModelClass
38
+ from clarifai.runners.utils.data_types import NamedFields
39
+ from typing import List, Iterator
40
+
41
+ class MyModel(ModelClass):
42
+
43
+ @ModelClass.method
44
+ def predict(self, x: str, y: int) -> List[str]:
45
+ return [x] * y
46
+
47
+ @ModelClass.method
48
+ def generate(self, x: str, y: int) -> Iterator[str]:
49
+ for i in range(y):
50
+ yield x + str(i)
51
+
52
+ @ModelClass.method
53
+ def stream(self, input_stream: Iterator[NamedFields(x=str, y=int)]) -> Iterator[str]:
54
+ for item in input_stream:
55
+ yield item.x + ' ' + str(item.y)
56
+ '''
57
+
58
@staticmethod
def method(func):
    """Decorator marking ``func`` as a model method exposed by the runner.

    Attaches ``_MethodInfo(func)`` to the function under ``_METHOD_INFO_ATTR`` and
    returns the very same function object, so it remains directly callable.
    """
    info = _MethodInfo(func)
    setattr(func, _METHOD_INFO_ATTR, info)
    return func
62
+
63
def set_output_context(self, prompt_tokens=None, completion_tokens=None):
    """Store token counts to be copied onto the next serialized ``Output`` proto.

    Args:
        prompt_tokens: Prompt token count, or None to leave it unset.
        completion_tokens: Completion token count, or None to leave it unset.
    """
    self._prompt_tokens, self._completion_tokens = prompt_tokens, completion_tokens
67
+
68
def load_model(self):
    """Load the model.

    The base implementation is a no-op and returns None; override it to load resources.
    """
    return None
70
+
71
def _handle_get_signatures_request(self) -> service_pb2.MultiOutputResponse:
    """Answer the ``_GET_SIGNATURES`` pseudo-method.

    Returns:
        A SUCCESS ``MultiOutputResponse`` with a single SUCCESS output whose
        ``data.text.raw`` is ``signatures_to_json`` of the ``{name: signature}``
        mapping for every registered model method.
    """
    signatures = {}
    for info in self._get_method_info().values():
        signatures[info.name] = info.signature
    response = service_pb2.MultiOutputResponse(
        status=status_pb2.Status(code=status_code_pb2.SUCCESS)
    )
    single_output = response.outputs.add()
    single_output.status.code = status_code_pb2.SUCCESS
    single_output.data.text.raw = signatures_to_json(signatures)
    return response
81
+
82
def _batch_predict(self, method, inputs: List[Dict[str, Any]]) -> List[Any]:
    """Call ``method`` once per kwargs dict in ``inputs``.

    Calls run sequentially in input order; the result list holds one entry per input.
    """
    return [method(**kwargs) for kwargs in inputs]
89
+
90
def _batch_generate(self, method, inputs: List[Dict[str, Any]]) -> Iterator[List[Any]]:
    """Advance one generator per input in lockstep.

    Each yielded item is a tuple with the next value of every generator, in input
    order. Generators that finish early contribute ``None`` until the longest one is
    exhausted (``itertools.zip_longest`` semantics).
    """
    streams = [method(**kwargs) for kwargs in inputs]
    yield from itertools.zip_longest(*streams)
95
+
96
def predict_wrapper(
    self, request: service_pb2.PostModelOutputsRequest
) -> service_pb2.MultiOutputResponse:
    """Handle a unary predict request.

    The target method defaults to ``predict``. The ``_method_name`` metadata key of
    the first input can name another registered method, and the special name
    ``_GET_SIGNATURES`` returns the method signatures instead.

    Inputs in the legacy data format are converted to the new format before
    deserialization. Each output is converted back to the legacy format when its own
    input was legacy. A single input calls the method directly; several inputs go
    through ``_batch_predict``.

    Returns:
        A SUCCESS ``MultiOutputResponse`` with one output per input. On error the
        exception is re-raised when ``_RAISE_EXCEPTIONS`` is set; otherwise it is
        logged and a FAILURE response with the message and stack trace is returned.

    Raises:
        ValueError: (surfaced as a FAILURE response unless ``_RAISE_EXCEPTIONS``) if
            the requested method is not a registered model method.
    """
    outputs = []
    try:
        # TODO add method name field to proto
        method_name = 'predict'
        inference_params = get_inference_params(request)
        if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
            method_name = request.inputs[0].data.metadata['_method_name']
        if method_name == '_GET_SIGNATURES':
            # special case to fetch signatures, TODO add endpoint for this
            return self._handle_get_signatures_request()
        if method_name not in self._get_method_info():
            raise ValueError(f"Method {method_name} not found in model class")
        method = getattr(self, method_name)
        method_info = method._cf_method_info
        signature = method_info.signature
        python_param_types = method_info.python_param_types

        # Remember the legacy-format flag per input so each output is answered in
        # the format its own input used (a single shared flag would take the last
        # input's value for the whole batch).
        convert_flags = []
        for inp in request.inputs:
            is_convert = DataConverter.is_old_format(inp.data)
            if is_convert:
                new_data = DataConverter.convert_input_data_to_new_format(
                    inp.data, signature.input_fields
                )
                inp.data.CopyFrom(new_data)
            convert_flags.append(is_convert)

        # convert inputs to python types (one kwargs dict per input)
        inputs = self._convert_input_protos_to_python(
            request.inputs, inference_params, signature.input_fields, python_param_types
        )
        if len(inputs) == 1:
            output = method(**inputs[0])
            outputs.append(
                self._convert_output_to_proto(
                    output, signature.output_fields, convert_old_format=convert_flags[0]
                )
            )
        else:
            batch_outputs = self._batch_predict(method, inputs)
            outputs = [
                self._convert_output_to_proto(
                    output, signature.output_fields, convert_old_format=flag
                )
                for output, flag in zip(batch_outputs, convert_flags)
            ]

        return service_pb2.MultiOutputResponse(
            outputs=outputs, status=status_pb2.Status(code=status_code_pb2.SUCCESS)
        )
    except Exception as e:
        if _RAISE_EXCEPTIONS:
            raise
        logging.exception("Error in predict")
        return service_pb2.MultiOutputResponse(
            status=status_pb2.Status(
                code=status_code_pb2.FAILURE,
                details=str(e),
                stack_trace=traceback.format_exc().split('\n'),
            )
        )
160
+
161
def generate_wrapper(
    self, request: service_pb2.PostModelOutputsRequest
) -> Iterator[service_pb2.MultiOutputResponse]:
    """Handle a server-streaming generate request.

    The target method defaults to ``generate``. The ``_method_name`` metadata key of
    the first input can name another registered method. Legacy-format inputs are
    converted to the new format, and each output is converted back when its own
    input was legacy.

    With a single input, one response is yielded per value the method yields. With
    several inputs, ``_batch_generate`` advances all generators together and each
    response holds one output per input.

    Yields:
        SUCCESS ``MultiOutputResponse`` messages. On error the exception is re-raised
        when ``_RAISE_EXCEPTIONS`` is set; otherwise it is logged and a single FAILURE
        response with the message and stack trace is yielded.
    """
    try:
        method_name = 'generate'
        inference_params = get_inference_params(request)
        if len(request.inputs) > 0 and '_method_name' in request.inputs[0].data.metadata:
            method_name = request.inputs[0].data.metadata['_method_name']
        # Same guard as predict_wrapper: only registered model methods may be
        # dispatched by the name supplied in request metadata.
        if method_name not in self._get_method_info():
            raise ValueError(f"Method {method_name} not found in model class")
        method = getattr(self, method_name)
        method_info = method._cf_method_info
        signature = method_info.signature
        python_param_types = method_info.python_param_types

        # Legacy-format flag per input, so mixed batches are answered per input.
        convert_flags = []
        for inp in request.inputs:
            is_convert = DataConverter.is_old_format(inp.data)
            if is_convert:
                new_data = DataConverter.convert_input_data_to_new_format(
                    inp.data, signature.input_fields
                )
                inp.data.CopyFrom(new_data)
            convert_flags.append(is_convert)

        inputs = self._convert_input_protos_to_python(
            request.inputs, inference_params, signature.input_fields, python_param_types
        )
        if len(inputs) == 1:
            for output in method(**inputs[0]):
                resp = service_pb2.MultiOutputResponse()
                self._convert_output_to_proto(
                    output,
                    signature.output_fields,
                    proto=resp.outputs.add(),
                    convert_old_format=convert_flags[0],
                )
                resp.status.code = status_code_pb2.SUCCESS
                yield resp
        else:
            for outputs in self._batch_generate(method, inputs):
                resp = service_pb2.MultiOutputResponse()
                for output, flag in zip(outputs, convert_flags):
                    self._convert_output_to_proto(
                        output,
                        signature.output_fields,
                        proto=resp.outputs.add(),
                        convert_old_format=flag,
                    )
                resp.status.code = status_code_pb2.SUCCESS
                yield resp
    except Exception as e:
        if _RAISE_EXCEPTIONS:
            raise
        logging.exception("Error in generate")
        yield service_pb2.MultiOutputResponse(
            status=status_pb2.Status(
                code=status_code_pb2.FAILURE,
                details=str(e),
                stack_trace=traceback.format_exc().split('\n'),
            )
        )
220
+
221
def stream_wrapper(
    self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
) -> Iterator[service_pb2.MultiOutputResponse]:
    """Handle a bidirectional-streaming request.

    The first request must carry exactly one input. It selects the method (default
    ``stream``, overridable through the ``_method_name`` metadata key) and supplies
    every argument, including the first item of the stream argument. Each later
    request contributes one further stream item. The method receives the stream
    argument as a generator, and one response is yielded per value it yields.

    Yields:
        SUCCESS ``MultiOutputResponse`` messages. On error the exception is re-raised
        when ``_RAISE_EXCEPTIONS`` is set; otherwise it is logged and a single FAILURE
        response with the message and stack trace is yielded.
    """
    try:
        # Explicit checks instead of a bare next()/assert: an exhausted iterator would
        # otherwise surface as an empty-message StopIteration, and asserts are
        # stripped under `python -O`.
        request = next(request_iterator, None)  # first request determines the method
        if request is None:
            raise ValueError("Streaming requires at least one request")
        if len(request.inputs) != 1:
            raise ValueError("Streaming requires exactly one input")

        method_name = 'stream'
        inference_params = get_inference_params(request)
        if '_method_name' in request.inputs[0].data.metadata:
            method_name = request.inputs[0].data.metadata['_method_name']
        # Same guard as predict_wrapper for names taken from request metadata.
        if method_name not in self._get_method_info():
            raise ValueError(f"Method {method_name} not found in model class")
        method = getattr(self, method_name)
        method_info = method._cf_method_info
        signature = method_info.signature
        python_param_types = method_info.python_param_types

        # find the streaming vars in the signature
        stream_sig = get_stream_from_signature(signature.input_fields)
        if stream_sig is None:
            raise ValueError("Streaming method must have a Stream input")
        stream_argname = stream_sig.name

        first_input = request.inputs[0]
        is_convert = DataConverter.is_old_format(first_input.data)
        if is_convert:
            new_data = DataConverter.convert_input_data_to_new_format(
                first_input.data, signature.input_fields
            )
            first_input.data.CopyFrom(new_data)
        # convert all inputs for the first request, including the first stream value
        kwargs = self._convert_input_protos_to_python(
            request.inputs, inference_params, signature.input_fields, python_param_types
        )[0]
        first_item = kwargs.pop(stream_argname)

        def InputStream():
            yield first_item
            # subsequent streaming items contain only the streaming input
            for next_request in request_iterator:
                item = self._convert_input_protos_to_python(
                    next_request.inputs, inference_params, [stream_sig], python_param_types
                )
                yield item[0][stream_argname]

        # add stream generator back to the input kwargs
        kwargs[stream_argname] = InputStream()

        for output in method(**kwargs):
            resp = service_pb2.MultiOutputResponse()
            self._convert_output_to_proto(
                output,
                signature.output_fields,
                proto=resp.outputs.add(),
                convert_old_format=is_convert,
            )
            resp.status.code = status_code_pb2.SUCCESS
            yield resp
    except Exception as e:
        if _RAISE_EXCEPTIONS:
            raise
        logging.exception("Error in stream")
        yield service_pb2.MultiOutputResponse(
            status=status_pb2.Status(
                code=status_code_pb2.FAILURE,
                details=str(e),
                stack_trace=traceback.format_exc().split('\n'),
            )
        )
296
+
297
def _convert_input_protos_to_python(
    self,
    inputs: List[resources_pb2.Input],
    inference_params: dict,
    variables_signature: List[resources_pb2.ModelTypeField],
    python_param_types,
) -> List[Dict[str, Any]]:
    """Deserialize input protos into keyword-argument dicts for a model method.

    ``deserialize`` builds the raw values from each input's ``data`` (together with
    ``inference_params``). Every value whose name appears in ``python_param_types``
    is then passed through ``data_types.cast`` with its annotation. For parameters
    annotated as ``Iterator[T]``, ``Generator[T, ...]`` or ``Iterable[T]``, the value
    is cast to the element type ``T`` instead.

    Returns:
        One kwargs dict per element of ``inputs``, in the same order.
    """
    stream_origins = [abc.Iterator, abc.Generator, abc.Iterable]
    converted = []
    for inp in inputs:
        kwargs = deserialize(inp.data, variables_signature, inference_params)
        for name, value in kwargs.items():
            if name not in python_param_types:
                continue
            annotation = python_param_types[name]
            is_stream = hasattr(annotation, "__args__") and (
                getattr(annotation, "__origin__", None) in stream_origins
            )
            # Stream-typed parameters are cast using the type of their items.
            target_type = annotation.__args__[0] if is_stream else annotation
            kwargs[name] = data_types.cast(value, target_type)
        converted.append(kwargs)
    return converted
324
+
325
def _convert_output_to_proto(
    self,
    output: Any,
    variables_signature: List[resources_pb2.ModelTypeField],
    proto=None,
    convert_old_format=False,
) -> resources_pb2.Output:
    """Serialize a model method's return value into an ``Output`` proto.

    Args:
        output: Python value returned by the model method.
        variables_signature: Output field signature handed to ``serialize``.
        proto: Existing ``Output`` message to fill in; a new one is created when None.
        convert_old_format: When True, the serialized data is rewritten with
            ``DataConverter.convert_output_data_to_old_format``.

    Returns:
        The populated ``Output`` proto (the same object as ``proto`` when one was given).

    Side effects:
        Copies pending ``_prompt_tokens`` / ``_completion_tokens`` (when set and not
        None) onto the proto, then resets both attributes to None.
    """
    out = resources_pb2.Output() if proto is None else proto
    serialize({'return': output}, variables_signature, out.data, is_output=True)
    if convert_old_format:
        # Inputs arrived in the legacy format, so answer in the legacy format too.
        out.data.CopyFrom(DataConverter.convert_output_data_to_old_format(out.data))
    out.status.code = status_code_pb2.SUCCESS
    pending_prompt = getattr(self, "_prompt_tokens", None)
    if pending_prompt is not None:
        out.prompt_tokens = pending_prompt
    pending_completion = getattr(self, "_completion_tokens", None)
    if pending_completion is not None:
        out.completion_tokens = pending_completion
    # Clear the counts so they are not copied onto a later output.
    self._prompt_tokens = None
    self._completion_tokens = None
    return out
347
+
348
@classmethod
def _register_model_methods(cls):
    """Collect every decorated model method visible on ``cls``.

    Walks the MRO from the most generic base to ``cls`` itself and records each
    class-namespace member that carries a truthy ``_METHOD_INFO_ATTR`` attribute.
    Because more derived classes are visited last, a decorated method in a subclass
    replaces the entry a parent registered under the same name.

    Returns:
        dict mapping attribute name to the method-info object attached by the decorator.
    """
    registry = {}
    for klass in cls.__mro__[::-1]:
        for attr_name, member in vars(klass).items():
            info = getattr(member, _METHOD_INFO_ATTR, None)
            if info:
                registry[attr_name] = info
    return registry
366
+
367
@classmethod
def _get_method_info(cls, func_name=None):
    """Return the model-method registry of ``cls``, building and caching it on first use.

    The registry is cached as a class attribute named ``_METHOD_INFO_ATTR``. Only
    ``cls``'s own namespace is consulted for the cache, so a subclass never reuses a
    registry that was built for one of its parents.

    Args:
        func_name: Optional method name; when given, only that method's info is returned.

    Returns:
        The info object for ``func_name``, or the whole ``{name: info}`` dict.

    Raises:
        KeyError: if ``func_name`` is given but is not a registered model method.
    """
    # hasattr() would also find a registry cached on a base class, which lacks the
    # methods this subclass adds; look in this class's own __dict__ instead.
    if _METHOD_INFO_ATTR not in cls.__dict__:
        setattr(cls, _METHOD_INFO_ATTR, cls._register_model_methods())
    method_info = getattr(cls, _METHOD_INFO_ATTR)
    if func_name:
        return method_info[func_name]
    return method_info
335
375
 
336
376
 
337
377
  # Helper function to get the inference params
338
378
  def get_inference_params(request) -> dict:
339
- """Get the inference params from the request."""
340
- inference_params = {}
341
- if request.model.model_version.id != "":
342
- output_info = request.model.model_version.output_info
343
- output_info = json_format.MessageToDict(output_info, preserving_proto_field_name=True)
344
- if "params" in output_info:
345
- inference_params = output_info["params"]
346
- return inference_params
379
+ """Get the inference params from the request."""
380
+ inference_params = {}
381
+ if request.model.model_version.id != "":
382
+ output_info = request.model.model_version.output_info
383
+ output_info = json_format.MessageToDict(output_info, preserving_proto_field_name=True)
384
+ if "params" in output_info:
385
+ inference_params = output_info["params"]
386
+ return inference_params
347
387
 
348
388
 
349
389
  class _MethodInfo:
350
-
351
def __init__(self, method):
    """Capture what the runner needs to know about a decorated model method.

    Sets:
        name: the function's ``__name__``.
        signature: the result of ``build_function_signature(method)``.
        python_param_types: ``{param_name: annotation}`` for every annotated
            parameter except ``self``.
    """
    self.name = method.__name__
    self.signature = build_function_signature(method)
    annotated = {}
    for param in inspect.signature(method).parameters.values():
        if param.annotation != inspect.Parameter.empty:
            annotated[param.name] = param.annotation
    annotated.pop('self', None)
    self.python_param_types = annotated