clarifai 11.3.0rc2__py3-none-any.whl → 11.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (300) hide show
  1. clarifai/__init__.py +1 -1
  2. clarifai/cli/__main__.py +1 -1
  3. clarifai/cli/base.py +144 -136
  4. clarifai/cli/compute_cluster.py +45 -31
  5. clarifai/cli/deployment.py +93 -76
  6. clarifai/cli/model.py +578 -180
  7. clarifai/cli/nodepool.py +100 -82
  8. clarifai/client/__init__.py +12 -2
  9. clarifai/client/app.py +973 -911
  10. clarifai/client/auth/helper.py +345 -342
  11. clarifai/client/auth/register.py +7 -7
  12. clarifai/client/auth/stub.py +107 -106
  13. clarifai/client/base.py +185 -178
  14. clarifai/client/compute_cluster.py +214 -180
  15. clarifai/client/dataset.py +793 -698
  16. clarifai/client/deployment.py +55 -50
  17. clarifai/client/input.py +1223 -1088
  18. clarifai/client/lister.py +47 -45
  19. clarifai/client/model.py +1939 -1717
  20. clarifai/client/model_client.py +525 -502
  21. clarifai/client/module.py +82 -73
  22. clarifai/client/nodepool.py +358 -213
  23. clarifai/client/runner.py +58 -0
  24. clarifai/client/search.py +342 -309
  25. clarifai/client/user.py +419 -414
  26. clarifai/client/workflow.py +294 -274
  27. clarifai/constants/dataset.py +11 -17
  28. clarifai/constants/model.py +8 -2
  29. clarifai/datasets/export/inputs_annotations.py +233 -217
  30. clarifai/datasets/upload/base.py +63 -51
  31. clarifai/datasets/upload/features.py +43 -38
  32. clarifai/datasets/upload/image.py +237 -207
  33. clarifai/datasets/upload/loaders/coco_captions.py +34 -32
  34. clarifai/datasets/upload/loaders/coco_detection.py +72 -65
  35. clarifai/datasets/upload/loaders/imagenet_classification.py +57 -53
  36. clarifai/datasets/upload/loaders/xview_detection.py +274 -132
  37. clarifai/datasets/upload/multimodal.py +55 -46
  38. clarifai/datasets/upload/text.py +55 -47
  39. clarifai/datasets/upload/utils.py +250 -234
  40. clarifai/errors.py +51 -50
  41. clarifai/models/api.py +260 -238
  42. clarifai/modules/css.py +50 -50
  43. clarifai/modules/pages.py +33 -33
  44. clarifai/rag/rag.py +312 -288
  45. clarifai/rag/utils.py +91 -84
  46. clarifai/runners/models/model_builder.py +906 -802
  47. clarifai/runners/models/model_class.py +370 -331
  48. clarifai/runners/models/model_run_locally.py +459 -419
  49. clarifai/runners/models/model_runner.py +170 -162
  50. clarifai/runners/models/model_servicer.py +78 -70
  51. clarifai/runners/server.py +111 -101
  52. clarifai/runners/utils/code_script.py +225 -187
  53. clarifai/runners/utils/const.py +4 -1
  54. clarifai/runners/utils/data_types/__init__.py +12 -0
  55. clarifai/runners/utils/data_types/data_types.py +598 -0
  56. clarifai/runners/utils/data_utils.py +387 -440
  57. clarifai/runners/utils/loader.py +247 -227
  58. clarifai/runners/utils/method_signatures.py +411 -386
  59. clarifai/runners/utils/openai_convertor.py +108 -109
  60. clarifai/runners/utils/serializers.py +175 -179
  61. clarifai/runners/utils/url_fetcher.py +35 -35
  62. clarifai/schema/search.py +56 -63
  63. clarifai/urls/helper.py +125 -102
  64. clarifai/utils/cli.py +129 -123
  65. clarifai/utils/config.py +127 -87
  66. clarifai/utils/constants.py +49 -0
  67. clarifai/utils/evaluation/helpers.py +503 -466
  68. clarifai/utils/evaluation/main.py +431 -393
  69. clarifai/utils/evaluation/testset_annotation_parser.py +154 -144
  70. clarifai/utils/logging.py +324 -306
  71. clarifai/utils/misc.py +60 -56
  72. clarifai/utils/model_train.py +165 -146
  73. clarifai/utils/protobuf.py +126 -103
  74. clarifai/versions.py +3 -1
  75. clarifai/workflows/export.py +48 -50
  76. clarifai/workflows/utils.py +39 -36
  77. clarifai/workflows/validate.py +55 -43
  78. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/METADATA +16 -6
  79. clarifai-11.4.0.dist-info/RECORD +109 -0
  80. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/WHEEL +1 -1
  81. clarifai/__pycache__/__init__.cpython-310.pyc +0 -0
  82. clarifai/__pycache__/__init__.cpython-311.pyc +0 -0
  83. clarifai/__pycache__/__init__.cpython-39.pyc +0 -0
  84. clarifai/__pycache__/errors.cpython-310.pyc +0 -0
  85. clarifai/__pycache__/errors.cpython-311.pyc +0 -0
  86. clarifai/__pycache__/versions.cpython-310.pyc +0 -0
  87. clarifai/__pycache__/versions.cpython-311.pyc +0 -0
  88. clarifai/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  89. clarifai/cli/__pycache__/__init__.cpython-311.pyc +0 -0
  90. clarifai/cli/__pycache__/base.cpython-310.pyc +0 -0
  91. clarifai/cli/__pycache__/base.cpython-311.pyc +0 -0
  92. clarifai/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  93. clarifai/cli/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  94. clarifai/cli/__pycache__/compute_cluster.cpython-311.pyc +0 -0
  95. clarifai/cli/__pycache__/deployment.cpython-310.pyc +0 -0
  96. clarifai/cli/__pycache__/deployment.cpython-311.pyc +0 -0
  97. clarifai/cli/__pycache__/model.cpython-310.pyc +0 -0
  98. clarifai/cli/__pycache__/model.cpython-311.pyc +0 -0
  99. clarifai/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  100. clarifai/cli/__pycache__/nodepool.cpython-310.pyc +0 -0
  101. clarifai/cli/__pycache__/nodepool.cpython-311.pyc +0 -0
  102. clarifai/client/__pycache__/__init__.cpython-310.pyc +0 -0
  103. clarifai/client/__pycache__/__init__.cpython-311.pyc +0 -0
  104. clarifai/client/__pycache__/__init__.cpython-39.pyc +0 -0
  105. clarifai/client/__pycache__/app.cpython-310.pyc +0 -0
  106. clarifai/client/__pycache__/app.cpython-311.pyc +0 -0
  107. clarifai/client/__pycache__/app.cpython-39.pyc +0 -0
  108. clarifai/client/__pycache__/base.cpython-310.pyc +0 -0
  109. clarifai/client/__pycache__/base.cpython-311.pyc +0 -0
  110. clarifai/client/__pycache__/compute_cluster.cpython-310.pyc +0 -0
  111. clarifai/client/__pycache__/compute_cluster.cpython-311.pyc +0 -0
  112. clarifai/client/__pycache__/dataset.cpython-310.pyc +0 -0
  113. clarifai/client/__pycache__/dataset.cpython-311.pyc +0 -0
  114. clarifai/client/__pycache__/deployment.cpython-310.pyc +0 -0
  115. clarifai/client/__pycache__/deployment.cpython-311.pyc +0 -0
  116. clarifai/client/__pycache__/input.cpython-310.pyc +0 -0
  117. clarifai/client/__pycache__/input.cpython-311.pyc +0 -0
  118. clarifai/client/__pycache__/lister.cpython-310.pyc +0 -0
  119. clarifai/client/__pycache__/lister.cpython-311.pyc +0 -0
  120. clarifai/client/__pycache__/model.cpython-310.pyc +0 -0
  121. clarifai/client/__pycache__/model.cpython-311.pyc +0 -0
  122. clarifai/client/__pycache__/module.cpython-310.pyc +0 -0
  123. clarifai/client/__pycache__/module.cpython-311.pyc +0 -0
  124. clarifai/client/__pycache__/nodepool.cpython-310.pyc +0 -0
  125. clarifai/client/__pycache__/nodepool.cpython-311.pyc +0 -0
  126. clarifai/client/__pycache__/search.cpython-310.pyc +0 -0
  127. clarifai/client/__pycache__/search.cpython-311.pyc +0 -0
  128. clarifai/client/__pycache__/user.cpython-310.pyc +0 -0
  129. clarifai/client/__pycache__/user.cpython-311.pyc +0 -0
  130. clarifai/client/__pycache__/workflow.cpython-310.pyc +0 -0
  131. clarifai/client/__pycache__/workflow.cpython-311.pyc +0 -0
  132. clarifai/client/auth/__pycache__/__init__.cpython-310.pyc +0 -0
  133. clarifai/client/auth/__pycache__/__init__.cpython-311.pyc +0 -0
  134. clarifai/client/auth/__pycache__/helper.cpython-310.pyc +0 -0
  135. clarifai/client/auth/__pycache__/helper.cpython-311.pyc +0 -0
  136. clarifai/client/auth/__pycache__/register.cpython-310.pyc +0 -0
  137. clarifai/client/auth/__pycache__/register.cpython-311.pyc +0 -0
  138. clarifai/client/auth/__pycache__/stub.cpython-310.pyc +0 -0
  139. clarifai/client/auth/__pycache__/stub.cpython-311.pyc +0 -0
  140. clarifai/client/cli/__init__.py +0 -0
  141. clarifai/client/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  142. clarifai/client/cli/__pycache__/base_cli.cpython-310.pyc +0 -0
  143. clarifai/client/cli/__pycache__/model_cli.cpython-310.pyc +0 -0
  144. clarifai/client/cli/base_cli.py +0 -88
  145. clarifai/client/cli/model_cli.py +0 -29
  146. clarifai/constants/__pycache__/base.cpython-310.pyc +0 -0
  147. clarifai/constants/__pycache__/base.cpython-311.pyc +0 -0
  148. clarifai/constants/__pycache__/dataset.cpython-310.pyc +0 -0
  149. clarifai/constants/__pycache__/dataset.cpython-311.pyc +0 -0
  150. clarifai/constants/__pycache__/input.cpython-310.pyc +0 -0
  151. clarifai/constants/__pycache__/input.cpython-311.pyc +0 -0
  152. clarifai/constants/__pycache__/model.cpython-310.pyc +0 -0
  153. clarifai/constants/__pycache__/model.cpython-311.pyc +0 -0
  154. clarifai/constants/__pycache__/rag.cpython-310.pyc +0 -0
  155. clarifai/constants/__pycache__/rag.cpython-311.pyc +0 -0
  156. clarifai/constants/__pycache__/search.cpython-310.pyc +0 -0
  157. clarifai/constants/__pycache__/search.cpython-311.pyc +0 -0
  158. clarifai/constants/__pycache__/workflow.cpython-310.pyc +0 -0
  159. clarifai/constants/__pycache__/workflow.cpython-311.pyc +0 -0
  160. clarifai/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  161. clarifai/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
  162. clarifai/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  163. clarifai/datasets/export/__pycache__/__init__.cpython-310.pyc +0 -0
  164. clarifai/datasets/export/__pycache__/__init__.cpython-311.pyc +0 -0
  165. clarifai/datasets/export/__pycache__/__init__.cpython-39.pyc +0 -0
  166. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-310.pyc +0 -0
  167. clarifai/datasets/export/__pycache__/inputs_annotations.cpython-311.pyc +0 -0
  168. clarifai/datasets/upload/__pycache__/__init__.cpython-310.pyc +0 -0
  169. clarifai/datasets/upload/__pycache__/__init__.cpython-311.pyc +0 -0
  170. clarifai/datasets/upload/__pycache__/__init__.cpython-39.pyc +0 -0
  171. clarifai/datasets/upload/__pycache__/base.cpython-310.pyc +0 -0
  172. clarifai/datasets/upload/__pycache__/base.cpython-311.pyc +0 -0
  173. clarifai/datasets/upload/__pycache__/features.cpython-310.pyc +0 -0
  174. clarifai/datasets/upload/__pycache__/features.cpython-311.pyc +0 -0
  175. clarifai/datasets/upload/__pycache__/image.cpython-310.pyc +0 -0
  176. clarifai/datasets/upload/__pycache__/image.cpython-311.pyc +0 -0
  177. clarifai/datasets/upload/__pycache__/multimodal.cpython-310.pyc +0 -0
  178. clarifai/datasets/upload/__pycache__/multimodal.cpython-311.pyc +0 -0
  179. clarifai/datasets/upload/__pycache__/text.cpython-310.pyc +0 -0
  180. clarifai/datasets/upload/__pycache__/text.cpython-311.pyc +0 -0
  181. clarifai/datasets/upload/__pycache__/utils.cpython-310.pyc +0 -0
  182. clarifai/datasets/upload/__pycache__/utils.cpython-311.pyc +0 -0
  183. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-311.pyc +0 -0
  184. clarifai/datasets/upload/loaders/__pycache__/__init__.cpython-39.pyc +0 -0
  185. clarifai/datasets/upload/loaders/__pycache__/coco_detection.cpython-311.pyc +0 -0
  186. clarifai/datasets/upload/loaders/__pycache__/imagenet_classification.cpython-311.pyc +0 -0
  187. clarifai/models/__pycache__/__init__.cpython-39.pyc +0 -0
  188. clarifai/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  189. clarifai/rag/__pycache__/__init__.cpython-310.pyc +0 -0
  190. clarifai/rag/__pycache__/__init__.cpython-311.pyc +0 -0
  191. clarifai/rag/__pycache__/__init__.cpython-39.pyc +0 -0
  192. clarifai/rag/__pycache__/rag.cpython-310.pyc +0 -0
  193. clarifai/rag/__pycache__/rag.cpython-311.pyc +0 -0
  194. clarifai/rag/__pycache__/rag.cpython-39.pyc +0 -0
  195. clarifai/rag/__pycache__/utils.cpython-310.pyc +0 -0
  196. clarifai/rag/__pycache__/utils.cpython-311.pyc +0 -0
  197. clarifai/runners/__pycache__/__init__.cpython-310.pyc +0 -0
  198. clarifai/runners/__pycache__/__init__.cpython-311.pyc +0 -0
  199. clarifai/runners/__pycache__/__init__.cpython-39.pyc +0 -0
  200. clarifai/runners/dockerfile_template/Dockerfile.cpu.template +0 -31
  201. clarifai/runners/dockerfile_template/Dockerfile.cuda.template +0 -42
  202. clarifai/runners/dockerfile_template/Dockerfile.nim +0 -71
  203. clarifai/runners/models/__pycache__/__init__.cpython-310.pyc +0 -0
  204. clarifai/runners/models/__pycache__/__init__.cpython-311.pyc +0 -0
  205. clarifai/runners/models/__pycache__/__init__.cpython-39.pyc +0 -0
  206. clarifai/runners/models/__pycache__/base_typed_model.cpython-310.pyc +0 -0
  207. clarifai/runners/models/__pycache__/base_typed_model.cpython-311.pyc +0 -0
  208. clarifai/runners/models/__pycache__/base_typed_model.cpython-39.pyc +0 -0
  209. clarifai/runners/models/__pycache__/model_builder.cpython-311.pyc +0 -0
  210. clarifai/runners/models/__pycache__/model_class.cpython-310.pyc +0 -0
  211. clarifai/runners/models/__pycache__/model_class.cpython-311.pyc +0 -0
  212. clarifai/runners/models/__pycache__/model_run_locally.cpython-310-pytest-7.1.2.pyc +0 -0
  213. clarifai/runners/models/__pycache__/model_run_locally.cpython-310.pyc +0 -0
  214. clarifai/runners/models/__pycache__/model_run_locally.cpython-311.pyc +0 -0
  215. clarifai/runners/models/__pycache__/model_runner.cpython-310.pyc +0 -0
  216. clarifai/runners/models/__pycache__/model_runner.cpython-311.pyc +0 -0
  217. clarifai/runners/models/__pycache__/model_upload.cpython-310.pyc +0 -0
  218. clarifai/runners/models/base_typed_model.py +0 -238
  219. clarifai/runners/models/model_class_refract.py +0 -80
  220. clarifai/runners/models/model_upload.py +0 -607
  221. clarifai/runners/models/temp.py +0 -25
  222. clarifai/runners/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  223. clarifai/runners/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  224. clarifai/runners/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  225. clarifai/runners/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  226. clarifai/runners/utils/__pycache__/buffered_stream.cpython-310.pyc +0 -0
  227. clarifai/runners/utils/__pycache__/buffered_stream.cpython-38.pyc +0 -0
  228. clarifai/runners/utils/__pycache__/buffered_stream.cpython-39.pyc +0 -0
  229. clarifai/runners/utils/__pycache__/const.cpython-310.pyc +0 -0
  230. clarifai/runners/utils/__pycache__/const.cpython-311.pyc +0 -0
  231. clarifai/runners/utils/__pycache__/constants.cpython-310.pyc +0 -0
  232. clarifai/runners/utils/__pycache__/constants.cpython-38.pyc +0 -0
  233. clarifai/runners/utils/__pycache__/constants.cpython-39.pyc +0 -0
  234. clarifai/runners/utils/__pycache__/data_handler.cpython-310.pyc +0 -0
  235. clarifai/runners/utils/__pycache__/data_handler.cpython-311.pyc +0 -0
  236. clarifai/runners/utils/__pycache__/data_handler.cpython-38.pyc +0 -0
  237. clarifai/runners/utils/__pycache__/data_handler.cpython-39.pyc +0 -0
  238. clarifai/runners/utils/__pycache__/data_utils.cpython-310.pyc +0 -0
  239. clarifai/runners/utils/__pycache__/data_utils.cpython-311.pyc +0 -0
  240. clarifai/runners/utils/__pycache__/data_utils.cpython-38.pyc +0 -0
  241. clarifai/runners/utils/__pycache__/data_utils.cpython-39.pyc +0 -0
  242. clarifai/runners/utils/__pycache__/grpc_server.cpython-310.pyc +0 -0
  243. clarifai/runners/utils/__pycache__/grpc_server.cpython-38.pyc +0 -0
  244. clarifai/runners/utils/__pycache__/grpc_server.cpython-39.pyc +0 -0
  245. clarifai/runners/utils/__pycache__/health.cpython-310.pyc +0 -0
  246. clarifai/runners/utils/__pycache__/health.cpython-38.pyc +0 -0
  247. clarifai/runners/utils/__pycache__/health.cpython-39.pyc +0 -0
  248. clarifai/runners/utils/__pycache__/loader.cpython-310.pyc +0 -0
  249. clarifai/runners/utils/__pycache__/loader.cpython-311.pyc +0 -0
  250. clarifai/runners/utils/__pycache__/logging.cpython-310.pyc +0 -0
  251. clarifai/runners/utils/__pycache__/logging.cpython-38.pyc +0 -0
  252. clarifai/runners/utils/__pycache__/logging.cpython-39.pyc +0 -0
  253. clarifai/runners/utils/__pycache__/stream_source.cpython-310.pyc +0 -0
  254. clarifai/runners/utils/__pycache__/stream_source.cpython-39.pyc +0 -0
  255. clarifai/runners/utils/__pycache__/url_fetcher.cpython-310.pyc +0 -0
  256. clarifai/runners/utils/__pycache__/url_fetcher.cpython-311.pyc +0 -0
  257. clarifai/runners/utils/__pycache__/url_fetcher.cpython-38.pyc +0 -0
  258. clarifai/runners/utils/__pycache__/url_fetcher.cpython-39.pyc +0 -0
  259. clarifai/runners/utils/data_handler.py +0 -231
  260. clarifai/runners/utils/data_handler_refract.py +0 -213
  261. clarifai/runners/utils/data_types.py +0 -469
  262. clarifai/runners/utils/logger.py +0 -0
  263. clarifai/runners/utils/openai_format.py +0 -87
  264. clarifai/schema/__pycache__/search.cpython-310.pyc +0 -0
  265. clarifai/schema/__pycache__/search.cpython-311.pyc +0 -0
  266. clarifai/urls/__pycache__/helper.cpython-310.pyc +0 -0
  267. clarifai/urls/__pycache__/helper.cpython-311.pyc +0 -0
  268. clarifai/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  269. clarifai/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  270. clarifai/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  271. clarifai/utils/__pycache__/cli.cpython-310.pyc +0 -0
  272. clarifai/utils/__pycache__/cli.cpython-311.pyc +0 -0
  273. clarifai/utils/__pycache__/config.cpython-311.pyc +0 -0
  274. clarifai/utils/__pycache__/constants.cpython-310.pyc +0 -0
  275. clarifai/utils/__pycache__/constants.cpython-311.pyc +0 -0
  276. clarifai/utils/__pycache__/logging.cpython-310.pyc +0 -0
  277. clarifai/utils/__pycache__/logging.cpython-311.pyc +0 -0
  278. clarifai/utils/__pycache__/misc.cpython-310.pyc +0 -0
  279. clarifai/utils/__pycache__/misc.cpython-311.pyc +0 -0
  280. clarifai/utils/__pycache__/model_train.cpython-310.pyc +0 -0
  281. clarifai/utils/__pycache__/model_train.cpython-311.pyc +0 -0
  282. clarifai/utils/__pycache__/protobuf.cpython-311.pyc +0 -0
  283. clarifai/utils/evaluation/__pycache__/__init__.cpython-311.pyc +0 -0
  284. clarifai/utils/evaluation/__pycache__/__init__.cpython-39.pyc +0 -0
  285. clarifai/utils/evaluation/__pycache__/helpers.cpython-311.pyc +0 -0
  286. clarifai/utils/evaluation/__pycache__/main.cpython-311.pyc +0 -0
  287. clarifai/utils/evaluation/__pycache__/main.cpython-39.pyc +0 -0
  288. clarifai/workflows/__pycache__/__init__.cpython-310.pyc +0 -0
  289. clarifai/workflows/__pycache__/__init__.cpython-311.pyc +0 -0
  290. clarifai/workflows/__pycache__/__init__.cpython-39.pyc +0 -0
  291. clarifai/workflows/__pycache__/export.cpython-310.pyc +0 -0
  292. clarifai/workflows/__pycache__/export.cpython-311.pyc +0 -0
  293. clarifai/workflows/__pycache__/utils.cpython-310.pyc +0 -0
  294. clarifai/workflows/__pycache__/utils.cpython-311.pyc +0 -0
  295. clarifai/workflows/__pycache__/validate.cpython-310.pyc +0 -0
  296. clarifai/workflows/__pycache__/validate.cpython-311.pyc +0 -0
  297. clarifai-11.3.0rc2.dist-info/RECORD +0 -322
  298. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/entry_points.txt +0 -0
  299. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info/licenses}/LICENSE +0 -0
  300. {clarifai-11.3.0rc2.dist-info → clarifai-11.4.0.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,6 @@
1
- import collections.abc as abc
2
1
  import inspect
3
2
  import json
4
- from collections import namedtuple
3
+ from collections import abc, namedtuple
5
4
  from typing import Dict, List, Tuple, get_args, get_origin
6
5
 
7
6
  import numpy as np
@@ -14,61 +13,72 @@ from google.protobuf.message import Message as MessageProto
14
13
  from clarifai.runners.utils import data_types, data_utils
15
14
  from clarifai.runners.utils.code_script import _get_base_type, _parse_default_value
16
15
  from clarifai.runners.utils.serializers import (
17
- AtomicFieldSerializer, JSONSerializer, ListSerializer, MessageSerializer,
18
- NamedFieldsSerializer, NDArraySerializer, Serializer, TupleSerializer)
16
+ AtomicFieldSerializer,
17
+ JSONSerializer,
18
+ ListSerializer,
19
+ MessageSerializer,
20
+ NamedFieldsSerializer,
21
+ NDArraySerializer,
22
+ Serializer,
23
+ TupleSerializer,
24
+ )
19
25
 
20
26
 
21
27
  def build_function_signature(func):
22
- '''
23
- Build a signature for the given function.
24
- '''
25
- sig = inspect.signature(func)
26
-
27
- # check if func is bound, and if not, remove self/cls
28
- if getattr(func, '__self__', None) is None and sig.parameters and list(
29
- sig.parameters.values())[0].name in ('self', 'cls'):
30
- sig = sig.replace(parameters=list(sig.parameters.values())[1:])
31
-
32
- return_annotation = sig.return_annotation
33
- if return_annotation == inspect.Parameter.empty:
34
- raise TypeError('Function must have a return annotation')
35
-
36
- input_sigs = []
37
- input_streaming = []
38
- for p in sig.parameters.values():
39
- model_type_field, _, streaming = build_variable_signature(p.name, p.annotation, p.default)
40
- input_sigs.append(model_type_field)
41
- input_streaming.append(streaming)
42
-
43
- output_sig, output_type, output_streaming = build_variable_signature(
44
- 'return', return_annotation, is_output=True)
45
- # TODO: flatten out "return" layer if not needed
46
-
47
- # check for streams and determine method type
48
- if sum(input_streaming) > 1:
49
- raise TypeError('streaming methods must have at most one streaming input')
50
- input_streaming = any(input_streaming)
51
- if not (input_streaming or output_streaming):
52
- method_type = 'UNARY_UNARY'
53
- elif not input_streaming and output_streaming:
54
- method_type = 'UNARY_STREAMING'
55
- elif input_streaming and output_streaming:
56
- method_type = 'STREAMING_STREAMING'
57
- else:
58
- raise TypeError('stream methods with streaming inputs must have streaming outputs')
59
-
60
- method_signature = resources_pb2.MethodSignature()
61
-
62
- method_signature.name = func.__name__
63
- method_signature.method_type = getattr(resources_pb2.RunnerMethodType, method_type)
64
- assert method_type in ('UNARY_UNARY', 'UNARY_STREAMING', 'STREAMING_STREAMING')
65
- # method_signature.method_type = method_type
66
- method_signature.description = inspect.cleandoc(func.__doc__ or '')
67
- # method_signature.annotations_json = json.dumps(_get_annotations_source(func))
68
-
69
- method_signature.input_fields.extend(input_sigs)
70
- method_signature.output_fields.append(output_sig)
71
- return method_signature
28
+ '''
29
+ Build a signature for the given function.
30
+ '''
31
+ sig = inspect.signature(func)
32
+
33
+ # check if func is bound, and if not, remove self/cls
34
+ if (
35
+ getattr(func, '__self__', None) is None
36
+ and sig.parameters
37
+ and list(sig.parameters.values())[0].name in ('self', 'cls')
38
+ ):
39
+ sig = sig.replace(parameters=list(sig.parameters.values())[1:])
40
+
41
+ return_annotation = sig.return_annotation
42
+ if return_annotation == inspect.Parameter.empty:
43
+ raise TypeError('Function must have a return annotation')
44
+
45
+ input_sigs = []
46
+ input_streaming = []
47
+ for p in sig.parameters.values():
48
+ model_type_field, _, streaming = build_variable_signature(p.name, p.annotation, p.default)
49
+ input_sigs.append(model_type_field)
50
+ input_streaming.append(streaming)
51
+
52
+ output_sig, output_type, output_streaming = build_variable_signature(
53
+ 'return', return_annotation, is_output=True
54
+ )
55
+ # TODO: flatten out "return" layer if not needed
56
+
57
+ # check for streams and determine method type
58
+ if sum(input_streaming) > 1:
59
+ raise TypeError('streaming methods must have at most one streaming input')
60
+ input_streaming = any(input_streaming)
61
+ if not (input_streaming or output_streaming):
62
+ method_type = 'UNARY_UNARY'
63
+ elif not input_streaming and output_streaming:
64
+ method_type = 'UNARY_STREAMING'
65
+ elif input_streaming and output_streaming:
66
+ method_type = 'STREAMING_STREAMING'
67
+ else:
68
+ raise TypeError('stream methods with streaming inputs must have streaming outputs')
69
+
70
+ method_signature = resources_pb2.MethodSignature()
71
+
72
+ method_signature.name = func.__name__
73
+ method_signature.method_type = getattr(resources_pb2.RunnerMethodType, method_type)
74
+ assert method_type in ('UNARY_UNARY', 'UNARY_STREAMING', 'STREAMING_STREAMING')
75
+ # method_signature.method_type = method_type
76
+ method_signature.description = inspect.cleandoc(func.__doc__ or '')
77
+ # method_signature.annotations_json = json.dumps(_get_annotations_source(func))
78
+
79
+ method_signature.input_fields.extend(input_sigs)
80
+ method_signature.output_fields.append(output_sig)
81
+ return method_signature
72
82
 
73
83
 
74
84
  # def _get_annotations_source(func):
@@ -91,336 +101,347 @@ def build_function_signature(func):
91
101
 
92
102
 
93
103
  def _process_input_field(field: resources_pb2.ModelTypeField) -> str:
94
- base_type = _get_base_type(field)
95
- if field.iterator:
96
- type_str = f"Iterator[{base_type}]"
97
- else:
98
- type_str = base_type
99
- default = _parse_default_value(field)
100
- param = f"{field.name}: {type_str}"
101
- if default is not None:
102
- param += f" = {default}"
103
- return param
104
+ base_type = _get_base_type(field)
105
+ if field.iterator:
106
+ type_str = f"Iterator[{base_type}]"
107
+ else:
108
+ type_str = base_type
109
+ default = _parse_default_value(field)
110
+ param = f"{field.name}: {type_str}"
111
+ if default is not None:
112
+ param += f" = {default}"
113
+ return param
104
114
 
105
115
 
106
116
  def _process_output_field(field: resources_pb2.ModelTypeField) -> str:
107
- base_type = _get_base_type(field)
108
- if field.iterator:
109
- return f"Iterator[{base_type}]"
110
- else:
111
- return base_type
117
+ base_type = _get_base_type(field)
118
+ if field.iterator:
119
+ return f"Iterator[{base_type}]"
120
+ else:
121
+ return base_type
112
122
 
113
123
 
114
124
  def get_method_signature(method_signature: resources_pb2.MethodSignature) -> str:
115
- """
125
+ """
116
126
  Get the method signature of a method in a model.
117
127
  """
118
- # Process input fields
119
- input_params = []
120
- for input_field in method_signature.input_fields:
121
- param_str = _process_input_field(input_field)
122
- input_params.append(param_str)
128
+ # Process input fields
129
+ input_params = []
130
+ for input_field in method_signature.input_fields:
131
+ param_str = _process_input_field(input_field)
132
+ input_params.append(param_str)
123
133
 
124
- # Process output field
125
- if not method_signature.output_fields:
126
- raise ValueError("MethodSignature must have at least one output field")
127
- output_field = method_signature.output_fields[0]
128
- return_type = _process_output_field(output_field)
134
+ # Process output field
135
+ if not method_signature.output_fields:
136
+ raise ValueError("MethodSignature must have at least one output field")
137
+ output_field = method_signature.output_fields[0]
138
+ return_type = _process_output_field(output_field)
129
139
 
130
- # Generate function signature
131
- function_def = f"def {method_signature.name}({', '.join(input_params)}) -> {return_type}:"
132
- return function_def
140
+ # Generate function signature
141
+ function_def = f"def {method_signature.name}({', '.join(input_params)}) -> {return_type}:"
142
+ return function_def
133
143
 
134
144
 
135
145
  def build_variable_signature(name, annotation, default=inspect.Parameter.empty, is_output=False):
136
- '''
137
- Build a data proto signature and get the normalized python type for the given annotation.
138
- '''
146
+ '''
147
+ Build a data proto signature and get the normalized python type for the given annotation.
148
+ '''
139
149
 
140
- # check valid names (should already be constrained by python naming, but check anyway)
141
- if not name.isidentifier():
142
- raise ValueError(f'Invalid variable name: {name}')
150
+ # check valid names (should already be constrained by python naming, but check anyway)
151
+ if not name.isidentifier():
152
+ raise ValueError(f'Invalid variable name: {name}')
143
153
 
144
- # get fields for each variable based on type
145
- tp, streaming = _normalize_type(annotation)
154
+ # get fields for each variable based on type
155
+ tp, streaming = _normalize_type(annotation)
146
156
 
147
- sig = resources_pb2.ModelTypeField()
148
- sig.name = name
149
- sig.iterator = streaming
157
+ sig = resources_pb2.ModelTypeField()
158
+ sig.name = name
159
+ sig.iterator = streaming
150
160
 
151
- if not is_output:
152
- sig.required = (default is inspect.Parameter.empty)
153
- if not sig.required:
154
- if isinstance(default, data_utils.InputField):
155
- sig = default.to_proto(sig)
156
- else:
157
- sig = data_utils.InputField.set_default(sig, default)
161
+ if not is_output:
162
+ sig.required = default is inspect.Parameter.empty
163
+ if not sig.required:
164
+ if isinstance(default, data_utils.InputField):
165
+ sig = default.to_proto(sig)
166
+ else:
167
+ sig = data_utils.InputField.set_default(sig, default)
158
168
 
159
- _fill_signature_type(sig, tp)
169
+ _fill_signature_type(sig, tp)
160
170
 
161
- return sig, type, streaming
171
+ return sig, type, streaming
162
172
 
163
173
 
164
174
  def _fill_signature_type(sig, tp):
165
- try:
166
- if tp in _DATA_TYPES:
167
- sig.type = _DATA_TYPES[tp].type
168
- return
169
- except TypeError:
170
- pass # not hashable type
171
-
172
- # Handle NamedFields with annotations
173
- # Check for dynamically generated NamedFields subclasses (from type annotations)
174
- if inspect.isclass(tp) and issubclass(tp, data_types.NamedFields) and hasattr(
175
- tp, '__annotations__'):
176
- sig.type = resources_pb2.ModelTypeField.DataType.NAMED_FIELDS
177
- for name, inner_type in tp.__annotations__.items():
178
- inner_sig = sig.type_args.add()
179
- inner_sig.name = name
180
- _fill_signature_type(inner_sig, inner_type)
181
- return
182
-
183
- # Handle NamedFields instances (dict-like)
184
- if isinstance(tp, data_types.NamedFields):
185
- sig.type = resources_pb2.ModelTypeField.DataType.NAMED_FIELDS
186
- for name, inner_type in tp.items():
187
- inner_sig = sig.type_args.add()
188
- inner_sig.name = name
189
- _fill_signature_type(inner_sig, inner_type)
190
- return
191
-
192
- origin = get_origin(tp)
193
- args = get_args(tp)
194
-
195
- # Handle Tuple type
196
- if origin == tuple:
197
- sig.type = resources_pb2.ModelTypeField.DataType.TUPLE
198
- for inner_type in args:
199
- inner_sig = sig.type_args.add()
200
- inner_sig.name = sig.name + '_item'
201
- _fill_signature_type(inner_sig, inner_type)
202
- return
203
-
204
- # Handle List type
205
- if origin == list:
206
- sig.type = resources_pb2.ModelTypeField.DataType.LIST
207
- inner_sig = sig.type_args.add()
208
- inner_sig.name = sig.name + '_item'
209
- _fill_signature_type(inner_sig, args[0])
210
- return
211
-
212
- raise TypeError(f'Unsupported type: {tp}')
175
+ try:
176
+ if tp in _DATA_TYPES:
177
+ sig.type = _DATA_TYPES[tp].type
178
+ return
179
+ except TypeError:
180
+ pass # not hashable type
181
+
182
+ # Handle NamedFields with annotations
183
+ # Check for dynamically generated NamedFields subclasses (from type annotations)
184
+ if (
185
+ inspect.isclass(tp)
186
+ and issubclass(tp, data_types.NamedFields)
187
+ and hasattr(tp, '__annotations__')
188
+ ):
189
+ sig.type = resources_pb2.ModelTypeField.DataType.NAMED_FIELDS
190
+ for name, inner_type in tp.__annotations__.items():
191
+ inner_sig = sig.type_args.add()
192
+ inner_sig.name = name
193
+ _fill_signature_type(inner_sig, inner_type)
194
+ return
195
+
196
+ # Handle NamedFields instances (dict-like)
197
+ if isinstance(tp, data_types.NamedFields):
198
+ sig.type = resources_pb2.ModelTypeField.DataType.NAMED_FIELDS
199
+ for name, inner_type in tp.items():
200
+ inner_sig = sig.type_args.add()
201
+ inner_sig.name = name
202
+ _fill_signature_type(inner_sig, inner_type)
203
+ return
204
+
205
+ origin = get_origin(tp)
206
+ args = get_args(tp)
207
+
208
+ # Handle Tuple type
209
+ if origin is tuple:
210
+ sig.type = resources_pb2.ModelTypeField.DataType.TUPLE
211
+ for inner_type in args:
212
+ inner_sig = sig.type_args.add()
213
+ inner_sig.name = sig.name + '_item'
214
+ _fill_signature_type(inner_sig, inner_type)
215
+ return
216
+
217
+ # Handle List type
218
+ if origin is list:
219
+ sig.type = resources_pb2.ModelTypeField.DataType.LIST
220
+ inner_sig = sig.type_args.add()
221
+ inner_sig.name = sig.name + '_item'
222
+ _fill_signature_type(inner_sig, args[0])
223
+ return
224
+
225
+ raise TypeError(f'Unsupported type: {tp}')
213
226
 
214
227
 
215
228
def serializer_from_signature(signature):
    '''
    Get the serializer for the given signature.

    Simple types resolve directly through the type-enum lookup table; the
    container types (LIST, TUPLE, NAMED_FIELDS) recurse into their type_args.
    '''
    sig_type = signature.type
    direct = _SERIALIZERS_BY_TYPE_ENUM.get(sig_type)
    if direct is not None:
        return direct
    data_type = resources_pb2.ModelTypeField.DataType
    if sig_type == data_type.LIST:
        # a list has a single homogeneous element type
        return ListSerializer(serializer_from_signature(signature.type_args[0]))
    if sig_type == data_type.TUPLE:
        # one serializer per positional element
        element_serializers = [serializer_from_signature(s) for s in signature.type_args]
        return TupleSerializer(element_serializers)
    if sig_type == data_type.NAMED_FIELDS:
        field_serializers = {s.name: serializer_from_signature(s) for s in signature.type_args}
        return NamedFieldsSerializer(field_serializers)
    raise ValueError(f'Unsupported type: {sig_type}')
230
243
 
231
244
 
232
245
def signatures_to_json(signatures):
    """Serialize a {name: MethodSignature} mapping to a JSON string."""
    assert isinstance(signatures, dict), (
        'Expected dict of signatures {name: signature}, got %s' % type(signatures)
    )
    # TODO change to proto when ready
    as_dicts = {}
    for name, sig in signatures.items():
        as_dicts[name] = MessageToDict(sig)
    return json.dumps(as_dicts)
238
252
 
239
253
 
240
254
def signatures_from_json(json_str):
    """Parse a JSON string back into a {name: MethodSignature} mapping."""
    decoded = json.loads(json_str)
    assert isinstance(decoded, dict), "Expected JSON to decode into a dictionary"

    result = {}
    for name, sig_dict in decoded.items():
        result[name] = ParseDict(sig_dict, resources_pb2.MethodSignature())
    return result
    # d = json.loads(json_str, object_pairs_hook=_SignatureDict)
    # return d
250
264
 
251
265
 
252
266
def signatures_to_yaml(signatures):
    """Render a {name: MethodSignature} mapping as YAML, omitting empty values."""
    # XXX go in/out of json to get the correct format and python dict types
    plain = json.loads(signatures_to_json(signatures))

    def _prune(node):
        # recursively drop falsy entries from lists and dicts
        if isinstance(node, (list, tuple)):
            return [_prune(item) for item in node if item]
        if isinstance(node, dict):
            return {key: _prune(value) for key, value in node.items() if value}
        return node

    return yaml.dump(_prune(plain), default_flow_style=False)
264
278
 
265
279
 
266
280
def signatures_from_yaml(yaml_str):
    """Parse a YAML string into a {name: MethodSignature} mapping (via JSON round-trip)."""
    loaded = yaml.safe_load(yaml_str)
    return signatures_from_json(json.dumps(loaded))
269
283
 
270
284
 
271
285
def serialize(kwargs, signatures, proto=None, is_output=False):
    '''
    Serialize the given kwargs into the proto using the given signatures.

    Args:
      kwargs: dict of {field_name: value} to serialize.
      signatures: list of MethodSignature field specs describing the fields.
      proto: optional Data proto to fill in place; a new one is created if None.
      is_output: unused in this direction; kept for symmetry with deserialize().

    Returns:
      The filled Data proto.

    Raises:
      TypeError: on unexpected keys or missing required fields.
    '''
    if proto is None:
        proto = resources_pb2.Data()
    # idiom fix: build the expected-name set with a comprehension; the original
    # also iterated signatures with enumerate() but never used the index.
    expected_names = {sig.name for sig in signatures}
    unknown = set(kwargs) - expected_names
    if unknown:
        if unknown == {'return'} and len(signatures) > 1:
            raise TypeError(
                'Got a single return value, but expected multiple outputs {%s}'
                % ', '.join(sig.name for sig in signatures)
            )
        raise TypeError('Got unexpected key: %s' % ', '.join(unknown))
    for sig in signatures:
        if sig.name not in kwargs:
            if sig.required:
                raise TypeError(f'Missing required argument: {sig.name}')
            continue  # skip missing fields, they can be set to default on the server
        data = kwargs[sig.name]
        serializer = serializer_from_signature(sig)
        # TODO determine if any (esp the first) var can go in the proto without parts
        # and whether to put this in the signature or dynamically determine it
        # add the part to the proto
        part = proto.parts.add()
        part.id = sig.name
        serializer.serialize(part.data, data)
    return proto
297
313
 
298
314
 
299
315
def deserialize(proto, signatures, inference_params=None, is_output=False):
    '''
    Deserialize the given proto into kwargs using the given signatures.

    Args:
      proto: Data proto whose parts are keyed by field name.
      signatures: list of MethodSignature field specs (a single dict is wrapped).
      inference_params: optional {field_name: value} fallbacks used when a
        field is absent from the proto. (Bug fix: was a mutable default
        argument `{}`; replaced with a None sentinel — same behavior.)
      is_output: if True, every field is treated as required.

    Returns:
      dict of {field_name: value}, or the bare value when the only key
      is 'return'.

    Raises:
      ValueError: when a required field is missing.
    '''
    if inference_params is None:
        inference_params = {}
    if isinstance(signatures, dict):
        signatures = [signatures]  # TODO update return key level and make consistnet
    kwargs = {}
    parts_by_name = {part.id: part for part in proto.parts}
    for sig_i, sig in enumerate(signatures):
        serializer = serializer_from_signature(sig)
        part = parts_by_name.get(sig.name)
        inference_params_value = inference_params.get(sig.name)
        if part is not None:
            kwargs[sig.name] = serializer.deserialize(part.data)
        elif inference_params_value is not None:
            kwargs[sig.name] = inference_params_value
        else:
            if sig_i == 0:
                # possible inlined first value
                value = serializer.deserialize(proto)
                if id(value) not in _ZERO_VALUE_IDS:
                    # note missing values are not set to defaults, since they are not in parts
                    # an actual zero value passed in must be set in an explicit part
                    kwargs[sig.name] = value
                    continue

            if sig.required or is_output:  # TODO allow optional outputs?
                raise ValueError(f'Missing required field: {sig.name}')
            continue
    if len(kwargs) == 1 and 'return' in kwargs:
        return kwargs['return']
    return kwargs
331
347
 
332
348
 
333
349
def get_stream_from_signature(signatures):
    '''
    Return the first streaming (iterator) signature, or None when no field
    in the given signatures is declared as a stream.
    '''
    return next((sig for sig in signatures if sig.iterator), None)
341
357
 
342
358
 
343
359
  def _is_empty_proto_data(data):
344
- if isinstance(data, np.ndarray):
345
- return False
346
- if isinstance(data, MessageProto):
347
- return not data.ByteSize()
348
- return not data
360
+ if isinstance(data, np.ndarray):
361
+ return False
362
+ if isinstance(data, MessageProto):
363
+ return not data.ByteSize()
364
+ return not data
349
365
 
350
366
 
351
367
def _normalize_type(tp):
    '''
    Normalize the type annotation for a parameter.

    Returns:
      (normalized_type, streaming): the normalized inner type and whether
      the parameter was declared as a stream (Iterator/Generator/Iterable).
    '''
    # A stream wrapper indicates streaming, not part of the data itself;
    # it is only meaningful at the top level, so unwrap it before normalizing.
    streaming = get_origin(tp) in (abc.Iterator, abc.Generator, abc.Iterable)
    inner = get_args(tp)[0] if streaming else tp
    return _normalize_data_type(inner), streaming
363
379
 
364
380
 
365
381
def _normalize_data_type(tp):
    """Map a type annotation to the canonical data type used in signatures."""
    origin = get_origin(tp)  # hoisted: tp is never reassigned below

    # container types that need to be serialized as parts
    if origin is list and get_args(tp):
        return List[_normalize_data_type(get_args(tp)[0])]

    if origin is tuple:
        if not get_args(tp):
            raise TypeError('Tuple must have types specified')
        return Tuple[tuple(_normalize_data_type(val) for val in get_args(tp))]

    if isinstance(tp, (tuple, list)):
        return Tuple[tuple(_normalize_data_type(val) for val in tp)]

    if tp is data_types.NamedFields:
        raise TypeError('NamedFields must have types specified')

    # Handle dynamically generated NamedFields subclasses with annotations
    if (
        isinstance(tp, type)
        and issubclass(tp, data_types.NamedFields)
        and hasattr(tp, '__annotations__')
    ):
        return data_types.NamedFields(
            **{k: _normalize_data_type(v) for k, v in tp.__annotations__.items()}
        )

    if isinstance(tp, (dict, data_types.NamedFields)):
        return data_types.NamedFields(
            **{name: _normalize_data_type(val) for name, val in tp.items()}
        )

    # numpy array annotations normalize to the bare ndarray type
    if origin is np.ndarray:
        return np.ndarray

    # check for PIL images (sometimes types use the module, sometimes the class)
    # set these to use the Image data handler
    if tp in (data_types.Image, PIL.Image.Image):
        return data_types.Image

    if tp is PIL.Image:
        raise TypeError('Use PIL.Image.Image instead of PIL.Image module')

    # jsonable list and dict, these can be serialized as json
    # (tuple we want to keep as a tuple for args and returns, so don't include here)
    if tp in (list, dict, Dict) or (origin in (list, dict, Dict) and _is_jsonable(tp)):
        return data_types.JSON

    # check for known data types
    try:
        if tp in _DATA_TYPES:
            return tp
    except TypeError:
        pass  # not hashable type

    raise TypeError(f'Unsupported type: {tp}')
416
437
 
417
438
 
418
439
  def _is_jsonable(tp):
419
- if tp in (dict, list, tuple, str, int, float, bool, type(None)):
420
- return True
421
- if get_origin(tp) in (tuple, list, dict):
422
- return all(_is_jsonable(val) for val in get_args(tp))
423
- return False
440
+ if tp in (dict, list, tuple, str, int, float, bool, type(None)):
441
+ return True
442
+ if get_origin(tp) in (tuple, list, dict):
443
+ return all(_is_jsonable(val) for val in get_args(tp))
444
+ return False
424
445
 
425
446
 
426
447
  # type: name of the data type
@@ -432,77 +453,81 @@ _ZERO_VALUE_IDS = {id(None), id(''), id(b''), id(0), id(0.0), id(False)}
432
453
 
433
454
  # simple, non-container types that correspond directly to a data field
434
455
_DATA_TYPES = {
    # atomic scalar values map onto a single proto value field
    str: _DataType(resources_pb2.ModelTypeField.DataType.STR, AtomicFieldSerializer('string_value')),
    bytes: _DataType(resources_pb2.ModelTypeField.DataType.BYTES, AtomicFieldSerializer('bytes_value')),
    int: _DataType(resources_pb2.ModelTypeField.DataType.INT, AtomicFieldSerializer('int_value')),
    float: _DataType(resources_pb2.ModelTypeField.DataType.FLOAT, AtomicFieldSerializer('float_value')),
    bool: _DataType(resources_pb2.ModelTypeField.DataType.BOOL, AtomicFieldSerializer('bool_value')),
    np.ndarray: _DataType(resources_pb2.ModelTypeField.DataType.NDARRAY, NDArraySerializer('ndarray')),
    # TODO change to json_value when new proto is ready
    data_types.JSON: _DataType(resources_pb2.ModelTypeField.DataType.JSON_DATA, JSONSerializer('string_value')),
    # rich message types wrap a dedicated proto message field
    data_types.Text: _DataType(resources_pb2.ModelTypeField.DataType.TEXT, MessageSerializer('text', data_types.Text)),
    data_types.Image: _DataType(resources_pb2.ModelTypeField.DataType.IMAGE, MessageSerializer('image', data_types.Image)),
    data_types.Concept: _DataType(resources_pb2.ModelTypeField.DataType.CONCEPT, MessageSerializer('concepts', data_types.Concept)),
    data_types.Region: _DataType(resources_pb2.ModelTypeField.DataType.REGION, MessageSerializer('regions', data_types.Region)),
    data_types.Frame: _DataType(resources_pb2.ModelTypeField.DataType.FRAME, MessageSerializer('frames', data_types.Frame)),
    data_types.Audio: _DataType(resources_pb2.ModelTypeField.DataType.AUDIO, MessageSerializer('audio', data_types.Audio)),
    data_types.Video: _DataType(resources_pb2.ModelTypeField.DataType.VIDEO, MessageSerializer('video', data_types.Video)),
}
475
499
 
476
500
# Reverse lookup: map each DataType enum value to its serializer instance.
_SERIALIZERS_BY_TYPE_ENUM = {dt.type: dt.serializer for dt in _DATA_TYPES.values()}
477
501
 
478
502
 
479
503
class CompatibilitySerializer(Serializer):
    '''
    Serialization of basic value types, used for backwards compatibility
    with older models that don't have type signatures.
    '''

    def serialize(self, data_proto, value):
        # Infer the canonical data type from the runtime value.
        tp = _normalize_data_type(type(value))

        # KeyError: unknown type; TypeError: unhashable type.
        try:
            value_serializer = _DATA_TYPES[tp].serializer
        except (KeyError, TypeError):
            raise TypeError(f'serializer currently only supports basic types, got {tp}')

        value_serializer.serialize(data_proto, value)

    def deserialize(self, data_proto):
        # Names of the proto fields that are actually populated.
        set_fields = {field.name for field, _ in data_proto.ListFields()}
        if 'parts' in set_fields:
            raise ValueError('serializer does not support parts')
        matching = [
            s for s in _SERIALIZERS_BY_TYPE_ENUM.values() if s.field_name in set_fields
        ]
        if not matching:
            raise ValueError('Returned data not recognized')
        if len(matching) != 1:
            raise ValueError('Only single output supported for serializer')
        return matching[0].deserialize(data_proto)