truss 0.10.0rc1__py3-none-any.whl → 0.60.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of truss might be problematic. Click here for more details.

Files changed (362) hide show
  1. truss/__init__.py +10 -3
  2. truss/api/__init__.py +123 -0
  3. truss/api/definitions.py +51 -0
  4. truss/base/constants.py +116 -0
  5. truss/base/custom_types.py +29 -0
  6. truss/{errors.py → base/errors.py} +4 -0
  7. truss/base/trt_llm_config.py +310 -0
  8. truss/{truss_config.py → base/truss_config.py} +344 -31
  9. truss/{truss_spec.py → base/truss_spec.py} +20 -6
  10. truss/{validation.py → base/validation.py} +60 -11
  11. truss/cli/cli.py +841 -88
  12. truss/{remote → cli}/remote_cli.py +2 -7
  13. truss/contexts/docker_build_setup.py +67 -0
  14. truss/contexts/image_builder/cache_warmer.py +2 -8
  15. truss/contexts/image_builder/image_builder.py +1 -1
  16. truss/contexts/image_builder/serving_image_builder.py +292 -46
  17. truss/contexts/image_builder/util.py +1 -3
  18. truss/contexts/local_loader/docker_build_emulator.py +58 -0
  19. truss/contexts/local_loader/load_model_local.py +2 -2
  20. truss/contexts/local_loader/truss_module_loader.py +1 -1
  21. truss/contexts/local_loader/utils.py +1 -1
  22. truss/local/local_config.py +2 -6
  23. truss/local/local_config_handler.py +20 -5
  24. truss/patch/__init__.py +1 -0
  25. truss/patch/hash.py +4 -70
  26. truss/patch/signature.py +4 -16
  27. truss/patch/truss_dir_patch_applier.py +3 -78
  28. truss/remote/baseten/api.py +308 -23
  29. truss/remote/baseten/auth.py +3 -3
  30. truss/remote/baseten/core.py +257 -50
  31. truss/remote/baseten/custom_types.py +44 -0
  32. truss/remote/baseten/error.py +4 -0
  33. truss/remote/baseten/remote.py +369 -118
  34. truss/remote/baseten/service.py +118 -11
  35. truss/remote/baseten/utils/status.py +29 -0
  36. truss/remote/baseten/utils/tar.py +34 -22
  37. truss/remote/baseten/utils/transfer.py +36 -23
  38. truss/remote/remote_factory.py +14 -5
  39. truss/remote/truss_remote.py +72 -45
  40. truss/templates/base.Dockerfile.jinja +18 -16
  41. truss/templates/cache.Dockerfile.jinja +3 -3
  42. truss/{server → templates/control}/control/application.py +14 -35
  43. truss/{server → templates/control}/control/endpoints.py +39 -9
  44. truss/{server/control/patch/types.py → templates/control/control/helpers/custom_types.py} +13 -52
  45. truss/{server → templates/control}/control/helpers/inference_server_controller.py +4 -8
  46. truss/{server → templates/control}/control/helpers/inference_server_process_controller.py +2 -4
  47. truss/{server → templates/control}/control/helpers/inference_server_starter.py +5 -10
  48. truss/{server/control → templates/control/control/helpers}/truss_patch/model_code_patch_applier.py +8 -6
  49. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/model_container_patch_applier.py +18 -26
  50. truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py +66 -0
  51. truss/{server → templates/control}/control/server.py +11 -6
  52. truss/templates/control/requirements.txt +9 -0
  53. truss/templates/custom_python_dx/my_model.py +28 -0
  54. truss/templates/docker_server/proxy.conf.jinja +42 -0
  55. truss/templates/docker_server/supervisord.conf.jinja +27 -0
  56. truss/templates/docker_server_requirements.txt +1 -0
  57. truss/templates/server/common/errors.py +231 -0
  58. truss/{server → templates/server}/common/patches/whisper/patch.py +1 -0
  59. truss/{server/common/patches/__init__.py → templates/server/common/patches.py} +1 -3
  60. truss/{server → templates/server}/common/retry.py +1 -0
  61. truss/{server → templates/server}/common/schema.py +11 -9
  62. truss/templates/server/common/tracing.py +157 -0
  63. truss/templates/server/main.py +9 -0
  64. truss/templates/server/model_wrapper.py +961 -0
  65. truss/templates/server/requirements.txt +21 -0
  66. truss/templates/server/truss_server.py +447 -0
  67. truss/templates/server.Dockerfile.jinja +62 -14
  68. truss/templates/shared/dynamic_config_resolver.py +28 -0
  69. truss/templates/shared/lazy_data_resolver.py +164 -0
  70. truss/templates/shared/log_config.py +125 -0
  71. truss/{server → templates}/shared/secrets_resolver.py +1 -2
  72. truss/{server → templates}/shared/serialization.py +31 -9
  73. truss/{server → templates}/shared/util.py +3 -13
  74. truss/templates/trtllm-audio/model/model.py +49 -0
  75. truss/templates/trtllm-audio/packages/sigint_patch.py +14 -0
  76. truss/templates/trtllm-audio/packages/whisper_trt/__init__.py +215 -0
  77. truss/templates/trtllm-audio/packages/whisper_trt/assets.py +25 -0
  78. truss/templates/trtllm-audio/packages/whisper_trt/batching.py +52 -0
  79. truss/templates/trtllm-audio/packages/whisper_trt/custom_types.py +26 -0
  80. truss/templates/trtllm-audio/packages/whisper_trt/modeling.py +184 -0
  81. truss/templates/trtllm-audio/packages/whisper_trt/tokenizer.py +185 -0
  82. truss/templates/trtllm-audio/packages/whisper_trt/utils.py +245 -0
  83. truss/templates/trtllm-briton/src/extension.py +64 -0
  84. truss/tests/conftest.py +302 -94
  85. truss/tests/contexts/image_builder/test_serving_image_builder.py +74 -31
  86. truss/tests/contexts/local_loader/test_load_local.py +2 -2
  87. truss/tests/contexts/local_loader/test_truss_module_finder.py +1 -1
  88. truss/tests/patch/test_calc_patch.py +439 -127
  89. truss/tests/patch/test_dir_signature.py +3 -12
  90. truss/tests/patch/test_hash.py +1 -1
  91. truss/tests/patch/test_signature.py +1 -1
  92. truss/tests/patch/test_truss_dir_patch_applier.py +23 -11
  93. truss/tests/patch/test_types.py +2 -2
  94. truss/tests/remote/baseten/test_api.py +153 -58
  95. truss/tests/remote/baseten/test_auth.py +2 -1
  96. truss/tests/remote/baseten/test_core.py +160 -12
  97. truss/tests/remote/baseten/test_remote.py +489 -77
  98. truss/tests/remote/baseten/test_service.py +55 -0
  99. truss/tests/remote/test_remote_factory.py +16 -18
  100. truss/tests/remote/test_truss_remote.py +26 -17
  101. truss/tests/templates/control/control/helpers/test_context_managers.py +11 -0
  102. truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py +184 -0
  103. truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py +89 -0
  104. truss/tests/{server → templates/control}/control/test_server.py +79 -24
  105. truss/tests/{server → templates/control}/control/test_server_integration.py +24 -16
  106. truss/tests/templates/core/server/test_dynamic_config_resolver.py +108 -0
  107. truss/tests/templates/core/server/test_lazy_data_resolver.py +329 -0
  108. truss/tests/templates/core/server/test_lazy_data_resolver_v2.py +79 -0
  109. truss/tests/{server → templates}/core/server/test_secrets_resolver.py +1 -1
  110. truss/tests/{server → templates/server}/common/test_retry.py +3 -3
  111. truss/tests/templates/server/test_model_wrapper.py +248 -0
  112. truss/tests/{server → templates/server}/test_schema.py +3 -5
  113. truss/tests/{server/core/server/common → templates/server}/test_truss_server.py +8 -5
  114. truss/tests/test_build.py +9 -52
  115. truss/tests/test_config.py +336 -77
  116. truss/tests/test_context_builder_image.py +3 -11
  117. truss/tests/test_control_truss_patching.py +7 -12
  118. truss/tests/test_custom_server.py +38 -0
  119. truss/tests/test_data/context_builder_image_test/test.py +3 -0
  120. truss/tests/test_data/happy.ipynb +56 -0
  121. truss/tests/test_data/model_load_failure_test/config.yaml +2 -0
  122. truss/tests/test_data/model_load_failure_test/model/__init__.py +0 -0
  123. truss/tests/test_data/patch_ping_test_server/__init__.py +0 -0
  124. truss/{test_data → tests/test_data}/patch_ping_test_server/app.py +3 -9
  125. truss/{test_data → tests/test_data}/server.Dockerfile +20 -21
  126. truss/tests/test_data/server_conformance_test_truss/__init__.py +0 -0
  127. truss/tests/test_data/server_conformance_test_truss/model/__init__.py +0 -0
  128. truss/{test_data → tests/test_data}/server_conformance_test_truss/model/model.py +1 -3
  129. truss/tests/test_data/test_async_truss/__init__.py +0 -0
  130. truss/tests/test_data/test_async_truss/model/__init__.py +0 -0
  131. truss/tests/test_data/test_basic_truss/__init__.py +0 -0
  132. truss/tests/test_data/test_basic_truss/config.yaml +16 -0
  133. truss/tests/test_data/test_basic_truss/model/__init__.py +0 -0
  134. truss/tests/test_data/test_build_commands/__init__.py +0 -0
  135. truss/tests/test_data/test_build_commands/config.yaml +13 -0
  136. truss/tests/test_data/test_build_commands/model/__init__.py +0 -0
  137. truss/{test_data/test_streaming_async_generator_truss → tests/test_data/test_build_commands}/model/model.py +2 -3
  138. truss/tests/test_data/test_build_commands_failure/__init__.py +0 -0
  139. truss/tests/test_data/test_build_commands_failure/config.yaml +14 -0
  140. truss/tests/test_data/test_build_commands_failure/model/__init__.py +0 -0
  141. truss/tests/test_data/test_build_commands_failure/model/model.py +17 -0
  142. truss/tests/test_data/test_concurrency_truss/__init__.py +0 -0
  143. truss/tests/test_data/test_concurrency_truss/config.yaml +4 -0
  144. truss/tests/test_data/test_concurrency_truss/model/__init__.py +0 -0
  145. truss/tests/test_data/test_custom_server_truss/__init__.py +0 -0
  146. truss/tests/test_data/test_custom_server_truss/config.yaml +20 -0
  147. truss/tests/test_data/test_custom_server_truss/test_docker_image/Dockerfile +17 -0
  148. truss/tests/test_data/test_custom_server_truss/test_docker_image/README.md +10 -0
  149. truss/tests/test_data/test_custom_server_truss/test_docker_image/VERSION +1 -0
  150. truss/tests/test_data/test_custom_server_truss/test_docker_image/__init__.py +0 -0
  151. truss/tests/test_data/test_custom_server_truss/test_docker_image/app.py +19 -0
  152. truss/tests/test_data/test_custom_server_truss/test_docker_image/build_upload_new_image.sh +6 -0
  153. truss/tests/test_data/test_openai/__init__.py +0 -0
  154. truss/{test_data/test_basic_truss → tests/test_data/test_openai}/config.yaml +1 -2
  155. truss/tests/test_data/test_openai/model/__init__.py +0 -0
  156. truss/tests/test_data/test_openai/model/model.py +15 -0
  157. truss/tests/test_data/test_pyantic_v1/__init__.py +0 -0
  158. truss/tests/test_data/test_pyantic_v1/model/__init__.py +0 -0
  159. truss/tests/test_data/test_pyantic_v1/model/model.py +28 -0
  160. truss/tests/test_data/test_pyantic_v1/requirements.txt +1 -0
  161. truss/tests/test_data/test_pyantic_v2/__init__.py +0 -0
  162. truss/tests/test_data/test_pyantic_v2/config.yaml +13 -0
  163. truss/tests/test_data/test_pyantic_v2/model/__init__.py +0 -0
  164. truss/tests/test_data/test_pyantic_v2/model/model.py +30 -0
  165. truss/tests/test_data/test_pyantic_v2/requirements.txt +1 -0
  166. truss/tests/test_data/test_requirements_file_truss/__init__.py +0 -0
  167. truss/tests/test_data/test_requirements_file_truss/config.yaml +13 -0
  168. truss/tests/test_data/test_requirements_file_truss/model/__init__.py +0 -0
  169. truss/{test_data → tests/test_data}/test_requirements_file_truss/model/model.py +1 -0
  170. truss/tests/test_data/test_streaming_async_generator_truss/__init__.py +0 -0
  171. truss/tests/test_data/test_streaming_async_generator_truss/config.yaml +4 -0
  172. truss/tests/test_data/test_streaming_async_generator_truss/model/__init__.py +0 -0
  173. truss/tests/test_data/test_streaming_async_generator_truss/model/model.py +7 -0
  174. truss/tests/test_data/test_streaming_read_timeout/__init__.py +0 -0
  175. truss/tests/test_data/test_streaming_read_timeout/model/__init__.py +0 -0
  176. truss/tests/test_data/test_streaming_truss/__init__.py +0 -0
  177. truss/tests/test_data/test_streaming_truss/config.yaml +4 -0
  178. truss/tests/test_data/test_streaming_truss/model/__init__.py +0 -0
  179. truss/tests/test_data/test_streaming_truss_with_error/__init__.py +0 -0
  180. truss/tests/test_data/test_streaming_truss_with_error/model/__init__.py +0 -0
  181. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/model/model.py +3 -11
  182. truss/tests/test_data/test_streaming_truss_with_error/packages/__init__.py +0 -0
  183. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_1.py +5 -0
  184. truss/tests/test_data/test_streaming_truss_with_error/packages/helpers_2.py +2 -0
  185. truss/tests/test_data/test_streaming_truss_with_tracing/__init__.py +0 -0
  186. truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml +43 -0
  187. truss/tests/test_data/test_streaming_truss_with_tracing/model/__init__.py +0 -0
  188. truss/tests/test_data/test_streaming_truss_with_tracing/model/model.py +65 -0
  189. truss/tests/test_data/test_trt_llm_truss/__init__.py +0 -0
  190. truss/tests/test_data/test_trt_llm_truss/config.yaml +15 -0
  191. truss/tests/test_data/test_trt_llm_truss/model/__init__.py +0 -0
  192. truss/tests/test_data/test_trt_llm_truss/model/model.py +15 -0
  193. truss/tests/test_data/test_truss/__init__.py +0 -0
  194. truss/tests/test_data/test_truss/config.yaml +4 -0
  195. truss/tests/test_data/test_truss/model/__init__.py +0 -0
  196. truss/tests/test_data/test_truss/model/dummy +0 -0
  197. truss/tests/test_data/test_truss/packages/__init__.py +0 -0
  198. truss/tests/test_data/test_truss/packages/test_package/__init__.py +0 -0
  199. truss/tests/test_data/test_truss_server_caching_truss/__init__.py +0 -0
  200. truss/tests/test_data/test_truss_server_caching_truss/model/__init__.py +0 -0
  201. truss/tests/test_data/test_truss_with_error/__init__.py +0 -0
  202. truss/tests/test_data/test_truss_with_error/config.yaml +4 -0
  203. truss/tests/test_data/test_truss_with_error/model/__init__.py +0 -0
  204. truss/tests/test_data/test_truss_with_error/model/model.py +8 -0
  205. truss/tests/test_data/test_truss_with_error/packages/__init__.py +0 -0
  206. truss/tests/test_data/test_truss_with_error/packages/helpers_1.py +5 -0
  207. truss/tests/test_data/test_truss_with_error/packages/helpers_2.py +2 -0
  208. truss/tests/test_docker.py +2 -1
  209. truss/tests/test_model_inference.py +1340 -292
  210. truss/tests/test_model_schema.py +33 -26
  211. truss/tests/test_testing_utilities_for_other_tests.py +50 -5
  212. truss/tests/test_truss_gatherer.py +3 -5
  213. truss/tests/test_truss_handle.py +62 -59
  214. truss/tests/test_util.py +2 -1
  215. truss/tests/test_validation.py +15 -13
  216. truss/tests/trt_llm/test_trt_llm_config.py +41 -0
  217. truss/tests/trt_llm/test_validation.py +91 -0
  218. truss/tests/util/test_config_checks.py +40 -0
  219. truss/tests/util/test_env_vars.py +14 -0
  220. truss/tests/util/test_path.py +10 -23
  221. truss/trt_llm/config_checks.py +43 -0
  222. truss/trt_llm/validation.py +42 -0
  223. truss/truss_handle/__init__.py +0 -0
  224. truss/truss_handle/build.py +122 -0
  225. truss/{decorators.py → truss_handle/decorators.py} +1 -1
  226. truss/truss_handle/patch/__init__.py +0 -0
  227. truss/{patch → truss_handle/patch}/calc_patch.py +146 -92
  228. truss/{types.py → truss_handle/patch/custom_types.py} +35 -27
  229. truss/{patch → truss_handle/patch}/dir_signature.py +1 -1
  230. truss/truss_handle/patch/hash.py +71 -0
  231. truss/{patch → truss_handle/patch}/local_truss_patch_applier.py +6 -4
  232. truss/truss_handle/patch/signature.py +22 -0
  233. truss/truss_handle/patch/truss_dir_patch_applier.py +87 -0
  234. truss/{readme_generator.py → truss_handle/readme_generator.py} +3 -2
  235. truss/{truss_gatherer.py → truss_handle/truss_gatherer.py} +3 -2
  236. truss/{truss_handle.py → truss_handle/truss_handle.py} +174 -78
  237. truss/util/.truss_ignore +3 -0
  238. truss/{docker.py → util/docker.py} +6 -2
  239. truss/util/download.py +6 -15
  240. truss/util/env_vars.py +41 -0
  241. truss/util/log_utils.py +52 -0
  242. truss/util/path.py +20 -20
  243. truss/util/requirements.py +11 -0
  244. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/METADATA +18 -16
  245. truss-0.60.0.dist-info/RECORD +324 -0
  246. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/WHEEL +1 -1
  247. truss-0.60.0.dist-info/entry_points.txt +4 -0
  248. truss_chains/__init__.py +71 -0
  249. truss_chains/definitions.py +756 -0
  250. truss_chains/deployment/__init__.py +0 -0
  251. truss_chains/deployment/code_gen.py +816 -0
  252. truss_chains/deployment/deployment_client.py +871 -0
  253. truss_chains/framework.py +1480 -0
  254. truss_chains/public_api.py +231 -0
  255. truss_chains/py.typed +0 -0
  256. truss_chains/pydantic_numpy.py +131 -0
  257. truss_chains/reference_code/reference_chainlet.py +34 -0
  258. truss_chains/reference_code/reference_model.py +10 -0
  259. truss_chains/remote_chainlet/__init__.py +0 -0
  260. truss_chains/remote_chainlet/model_skeleton.py +60 -0
  261. truss_chains/remote_chainlet/stub.py +380 -0
  262. truss_chains/remote_chainlet/utils.py +332 -0
  263. truss_chains/streaming.py +378 -0
  264. truss_chains/utils.py +178 -0
  265. CODE_OF_CONDUCT.md +0 -131
  266. CONTRIBUTING.md +0 -48
  267. README.md +0 -137
  268. context_builder.Dockerfile +0 -24
  269. truss/blob/blob_backend.py +0 -10
  270. truss/blob/blob_backend_registry.py +0 -23
  271. truss/blob/http_public_blob_backend.py +0 -23
  272. truss/build/__init__.py +0 -2
  273. truss/build/build.py +0 -143
  274. truss/build/configure.py +0 -63
  275. truss/cli/__init__.py +0 -2
  276. truss/cli/console.py +0 -5
  277. truss/cli/create.py +0 -5
  278. truss/config/trt_llm.py +0 -81
  279. truss/constants.py +0 -61
  280. truss/model_inference.py +0 -123
  281. truss/patch/types.py +0 -30
  282. truss/pytest.ini +0 -7
  283. truss/server/common/errors.py +0 -100
  284. truss/server/common/termination_handler_middleware.py +0 -64
  285. truss/server/common/truss_server.py +0 -389
  286. truss/server/control/patch/model_code_patch_applier.py +0 -46
  287. truss/server/control/patch/requirement_name_identifier.py +0 -17
  288. truss/server/inference_server.py +0 -29
  289. truss/server/model_wrapper.py +0 -434
  290. truss/server/shared/logging.py +0 -81
  291. truss/templates/trtllm/model/model.py +0 -97
  292. truss/templates/trtllm/packages/build_engine_utils.py +0 -34
  293. truss/templates/trtllm/packages/constants.py +0 -11
  294. truss/templates/trtllm/packages/schema.py +0 -216
  295. truss/templates/trtllm/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +0 -246
  296. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +0 -181
  297. truss/templates/trtllm/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +0 -64
  298. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +0 -260
  299. truss/templates/trtllm/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +0 -99
  300. truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +0 -208
  301. truss/templates/trtllm/packages/triton_client.py +0 -150
  302. truss/templates/trtllm/packages/utils.py +0 -43
  303. truss/test_data/context_builder_image_test/test.py +0 -4
  304. truss/test_data/happy.ipynb +0 -54
  305. truss/test_data/model_load_failure_test/config.yaml +0 -2
  306. truss/test_data/test_concurrency_truss/config.yaml +0 -2
  307. truss/test_data/test_streaming_async_generator_truss/config.yaml +0 -2
  308. truss/test_data/test_streaming_truss/config.yaml +0 -3
  309. truss/test_data/test_truss/config.yaml +0 -2
  310. truss/tests/server/common/test_termination_handler_middleware.py +0 -93
  311. truss/tests/server/control/test_model_container_patch_applier.py +0 -203
  312. truss/tests/server/core/server/common/test_util.py +0 -19
  313. truss/tests/server/test_model_wrapper.py +0 -87
  314. truss/util/data_structures.py +0 -16
  315. truss-0.10.0rc1.dist-info/RECORD +0 -216
  316. truss-0.10.0rc1.dist-info/entry_points.txt +0 -3
  317. truss/{server/shared → base}/__init__.py +0 -0
  318. truss/{server → templates/control}/control/helpers/context_managers.py +0 -0
  319. truss/{server/control → templates/control/control/helpers}/errors.py +0 -0
  320. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/__init__.py +0 -0
  321. truss/{server/control/patch → templates/control/control/helpers/truss_patch}/system_packages.py +0 -0
  322. truss/{test_data/annotated_types_truss/model → templates/server}/__init__.py +0 -0
  323. truss/{server → templates/server}/common/__init__.py +0 -0
  324. truss/{test_data/gcs_fix/model → templates/shared}/__init__.py +0 -0
  325. truss/templates/{trtllm → trtllm-briton}/README.md +0 -0
  326. truss/{test_data/server_conformance_test_truss/model → tests/test_data}/__init__.py +0 -0
  327. truss/{test_data/test_basic_truss/model → tests/test_data/annotated_types_truss}/__init__.py +0 -0
  328. truss/{test_data → tests/test_data}/annotated_types_truss/config.yaml +0 -0
  329. truss/{test_data/test_requirements_file_truss → tests/test_data/annotated_types_truss}/model/__init__.py +0 -0
  330. truss/{test_data → tests/test_data}/annotated_types_truss/model/model.py +0 -0
  331. truss/{test_data → tests/test_data}/auto-mpg.data +0 -0
  332. truss/{test_data → tests/test_data}/context_builder_image_test/Dockerfile +0 -0
  333. truss/{test_data/test_truss/model → tests/test_data/context_builder_image_test}/__init__.py +0 -0
  334. truss/{test_data/test_truss_server_caching_truss/model → tests/test_data/gcs_fix}/__init__.py +0 -0
  335. truss/{test_data → tests/test_data}/gcs_fix/config.yaml +0 -0
  336. truss/tests/{local → test_data/gcs_fix/model}/__init__.py +0 -0
  337. truss/{test_data → tests/test_data}/gcs_fix/model/model.py +0 -0
  338. truss/{test_data/test_truss/model/dummy → tests/test_data/model_load_failure_test/__init__.py} +0 -0
  339. truss/{test_data → tests/test_data}/model_load_failure_test/model/model.py +0 -0
  340. truss/{test_data → tests/test_data}/pima-indians-diabetes.csv +0 -0
  341. truss/{test_data → tests/test_data}/readme_int_example.md +0 -0
  342. truss/{test_data → tests/test_data}/readme_no_example.md +0 -0
  343. truss/{test_data → tests/test_data}/readme_str_example.md +0 -0
  344. truss/{test_data → tests/test_data}/server_conformance_test_truss/config.yaml +0 -0
  345. truss/{test_data → tests/test_data}/test_async_truss/config.yaml +0 -0
  346. truss/{test_data → tests/test_data}/test_async_truss/model/model.py +3 -3
  347. truss/{test_data → tests/test_data}/test_basic_truss/model/model.py +0 -0
  348. truss/{test_data → tests/test_data}/test_concurrency_truss/model/model.py +0 -0
  349. truss/{test_data/test_requirements_file_truss → tests/test_data/test_pyantic_v1}/config.yaml +0 -0
  350. truss/{test_data → tests/test_data}/test_requirements_file_truss/requirements.txt +0 -0
  351. truss/{test_data → tests/test_data}/test_streaming_read_timeout/config.yaml +0 -0
  352. truss/{test_data → tests/test_data}/test_streaming_read_timeout/model/model.py +0 -0
  353. truss/{test_data → tests/test_data}/test_streaming_truss/model/model.py +0 -0
  354. truss/{test_data → tests/test_data}/test_streaming_truss_with_error/config.yaml +0 -0
  355. truss/{test_data → tests/test_data}/test_truss/examples.yaml +0 -0
  356. truss/{test_data → tests/test_data}/test_truss/model/model.py +0 -0
  357. truss/{test_data → tests/test_data}/test_truss/packages/test_package/test.py +0 -0
  358. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/config.yaml +0 -0
  359. truss/{test_data → tests/test_data}/test_truss_server_caching_truss/model/model.py +0 -0
  360. truss/{patch → truss_handle/patch}/constants.py +0 -0
  361. truss/{notebook.py → util/notebook.py} +0 -0
  362. {truss-0.10.0rc1.dist-info → truss-0.60.0.dist-info}/LICENSE +0 -0
@@ -0,0 +1,55 @@
1
+ from truss.remote.baseten import service
2
+
3
+
4
+ def test_model_invocation_url_prod():
5
+ url = service.URLConfig.invocation_url(
6
+ "https://api.baseten.co", service.URLConfig.MODEL, "123", "789", is_draft=False
7
+ )
8
+ assert url == "https://model-123.api.baseten.co/deployment/789/predict"
9
+
10
+
11
+ def test_model_invocation_url_draft():
12
+ url = service.URLConfig.invocation_url(
13
+ "https://api.baseten.co", service.URLConfig.MODEL, "123", "789", is_draft=True
14
+ )
15
+ assert url == "https://model-123.api.baseten.co/development/predict"
16
+
17
+
18
+ def test_chain_invocation_url_prod():
19
+ url = service.URLConfig.invocation_url(
20
+ "https://api.baseten.co", service.URLConfig.CHAIN, "abc", "666", is_draft=False
21
+ )
22
+ assert url == "https://chain-abc.api.baseten.co/deployment/666/run_remote"
23
+
24
+
25
+ def test_chain_invocation_url_draft():
26
+ url = service.URLConfig.invocation_url(
27
+ "https://api.baseten.co", service.URLConfig.CHAIN, "abc", "666", is_draft=True
28
+ )
29
+ assert url == "https://chain-abc.api.baseten.co/development/run_remote"
30
+
31
+
32
+ def test_model_status_page_url():
33
+ url = service.URLConfig.status_page_url(
34
+ "https://app.baseten.co", service.URLConfig.MODEL, "123"
35
+ )
36
+ assert url == "https://app.baseten.co/models/123/overview"
37
+
38
+
39
+ def test_chain_status_page_url():
40
+ url = service.URLConfig.status_page_url(
41
+ "https://app.baseten.co", service.URLConfig.CHAIN, "abc"
42
+ )
43
+ assert url == "https://app.baseten.co/chains/abc/overview"
44
+
45
+
46
+ def test_model_logs_url():
47
+ url = service.URLConfig.model_logs_url("https://app.baseten.co", "123", "789")
48
+ assert url == "https://app.baseten.co/models/123/logs/789"
49
+
50
+
51
+ def test_chain_logs_url():
52
+ url = service.URLConfig.chainlet_logs_url(
53
+ "https://app.baseten.co", "abc", "666", "543"
54
+ )
55
+ assert url == "https://app.baseten.co/chains/abc/logs/666/543"
@@ -2,7 +2,7 @@ from unittest import mock
2
2
 
3
3
  import pytest
4
4
  from truss.remote.remote_factory import RemoteFactory
5
- from truss.remote.truss_remote import RemoteConfig, TrussRemote, TrussService
5
+ from truss.remote.truss_remote import RemoteConfig, RemoteUser, TrussRemote
6
6
 
7
7
  SAMPLE_CONFIG = {"api_key": "test_key", "remote_url": "http://test.com"}
8
8
 
@@ -25,10 +25,9 @@ remote_provider=test_remote
25
25
  """
26
26
 
27
27
 
28
- class TestRemote(TrussRemote):
28
+ class TrussTestRemote(TrussRemote):
29
29
  def __init__(self, api_key, remote_url):
30
30
  self.api_key = api_key
31
- self.remote_url = remote_url
32
31
 
33
32
  def authenticate(self):
34
33
  return {"Authorization": self.api_key}
@@ -36,20 +35,19 @@ class TestRemote(TrussRemote):
36
35
  def push(self):
37
36
  return {"status": "success"}
38
37
 
39
- def get_remote_logs_url(self, service: TrussService) -> str:
40
- raise NotImplementedError
41
-
42
38
  def get_service(self, **kwargs):
43
39
  raise NotImplementedError
44
40
 
45
41
  def sync_truss_to_dev_version_by_name(self, model_name: str, target_directory: str):
46
42
  raise NotImplementedError
47
43
 
44
+ def whoami(self) -> RemoteUser:
45
+ return RemoteUser("test_user", "test_email")
46
+
48
47
 
49
48
  def mock_service_config():
50
49
  return RemoteConfig(
51
- name="mock-service",
52
- configs={"remote_provider": "test_remote", **SAMPLE_CONFIG},
50
+ name="mock-service", configs={"remote_provider": "test_remote", **SAMPLE_CONFIG}
53
51
  )
54
52
 
55
53
 
@@ -60,7 +58,7 @@ def mock_incorrect_service_config():
60
58
  )
61
59
 
62
60
 
63
- @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
61
+ @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
64
62
  @mock.patch(
65
63
  "truss.remote.remote_factory.RemoteFactory.load_remote_config",
66
64
  return_value=mock_service_config(),
@@ -69,10 +67,10 @@ def test_create(mock_load_remote_config):
69
67
  service_name = "test_service"
70
68
  remote = RemoteFactory.create(service_name)
71
69
  mock_load_remote_config.assert_called_once_with(service_name)
72
- assert isinstance(remote, TestRemote)
70
+ assert isinstance(remote, TrussTestRemote)
73
71
 
74
72
 
75
- @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
73
+ @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
76
74
  @mock.patch(
77
75
  "truss.remote.remote_factory.RemoteFactory.load_remote_config",
78
76
  return_value=mock_incorrect_service_config(),
@@ -83,7 +81,7 @@ def test_create_no_service(mock_load_remote_config):
83
81
  RemoteFactory.create(service_name)
84
82
 
85
83
 
86
- @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
84
+ @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
87
85
  @mock.patch("builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC)
88
86
  @mock.patch("pathlib.Path.exists", return_value=True)
89
87
  def test_load_remote_config(mock_exists, mock_open):
@@ -92,7 +90,7 @@ def test_load_remote_config(mock_exists, mock_open):
92
90
  assert service.configs == {"remote_provider": "test_remote", **SAMPLE_CONFIG}
93
91
 
94
92
 
95
- @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
93
+ @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
96
94
  @mock.patch("builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC)
97
95
  @mock.patch("pathlib.Path.exists", return_value=False)
98
96
  def test_load_remote_config_no_file(mock_exists, mock_open):
@@ -100,7 +98,7 @@ def test_load_remote_config_no_file(mock_exists, mock_open):
100
98
  RemoteFactory.load_remote_config("test")
101
99
 
102
100
 
103
- @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
101
+ @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
104
102
  @mock.patch("builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC)
105
103
  @mock.patch("pathlib.Path.exists", return_value=True)
106
104
  def test_load_remote_config_no_service(mock_exists, mock_open):
@@ -108,13 +106,13 @@ def test_load_remote_config_no_service(mock_exists, mock_open):
108
106
  RemoteFactory.load_remote_config("nonexistent_service")
109
107
 
110
108
 
111
- @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
109
+ @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
112
110
  def test_required_params():
113
- required_params = RemoteFactory.required_params(TestRemote)
111
+ required_params = RemoteFactory.required_params(TrussTestRemote)
114
112
  assert required_params == {"api_key", "remote_url"}
115
113
 
116
114
 
117
- @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
115
+ @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
118
116
  @mock.patch(
119
117
  "builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC_NO_REMOTE
120
118
  )
@@ -125,7 +123,7 @@ def test_validate_remote_config_no_remote(mock_exists, mock_open):
125
123
  RemoteFactory.validate_remote_config(service.configs, "test")
126
124
 
127
125
 
128
- @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TestRemote}, clear=True)
126
+ @mock.patch.dict(RemoteFactory.REGISTRY, {"test_remote": TrussTestRemote}, clear=True)
129
127
  @mock.patch(
130
128
  "builtins.open", new_callable=mock.mock_open, read_data=SAMPLE_TRUSSRC_NO_PARAMS
131
129
  )
@@ -1,3 +1,4 @@
1
+ from typing import Iterator
1
2
  from unittest import mock
2
3
 
3
4
  import pytest
@@ -7,28 +8,36 @@ from truss.remote.truss_remote import TrussService
7
8
  TEST_SERVICE_URL = "http://test.com"
8
9
 
9
10
 
10
- class TestTrussService(TrussService):
11
- def __init__(self, remote_url: str, is_draft: bool, **kwargs):
12
- super().__init__(remote_url, is_draft, **kwargs)
13
- self._remote_url = remote_url
11
+ class TrussTestService(TrussService):
12
+ def __init__(self, _service_url: str, is_draft: bool, **kwargs):
13
+ super().__init__(_service_url, is_draft, **kwargs)
14
14
 
15
15
  def authenticate(self):
16
- return {"Authorization": "Test"}
16
+ return {}
17
17
 
18
- def is_live(self):
19
- response = self._send_request(self._remote_url, "GET")
18
+ def is_live(self) -> bool:
19
+ response = self._send_request(self._service_url, "GET")
20
20
  if response.status_code == 200:
21
21
  return True
22
22
  return False
23
23
 
24
- def is_ready(self):
25
- response = self._send_request(self._remote_url, "GET")
24
+ def is_ready(self) -> bool:
25
+ response = self._send_request(self._service_url, "GET")
26
26
  if response.status_code == 200:
27
27
  return True
28
28
  return False
29
29
 
30
- def logs_url(self, base_url: str) -> str:
31
- return f"{base_url}/logs"
30
+ @property
31
+ def logs_url(self) -> str:
32
+ raise NotImplementedError()
33
+
34
+ @property
35
+ def predict_url(self) -> str:
36
+ return f"{self._service_url}/v1/models/model:predict"
37
+
38
+ def poll_deployment_status(self, sleep_secs: int = 1) -> Iterator[str]:
39
+ for status in ["DEPLOYING", "ACTIVE"]:
40
+ yield status
32
41
 
33
42
 
34
43
  def mock_successful_response():
@@ -46,37 +55,37 @@ def mock_unsuccessful_response():
46
55
 
47
56
  @mock.patch("requests.request", return_value=mock_successful_response())
48
57
  def test_is_live(mock_request):
49
- service = TestTrussService(TEST_SERVICE_URL, True)
58
+ service = TrussTestService(TEST_SERVICE_URL, True)
50
59
  assert service.is_live()
51
60
 
52
61
 
53
62
  @mock.patch("requests.request", return_value=mock_unsuccessful_response())
54
63
  def test_is_not_live(mock_request):
55
- service = TestTrussService(TEST_SERVICE_URL, True)
64
+ service = TrussTestService(TEST_SERVICE_URL, True)
56
65
  assert service.is_live() is False
57
66
 
58
67
 
59
68
  @mock.patch("requests.request", return_value=mock_successful_response())
60
69
  def test_is_ready(mock_request):
61
- service = TestTrussService(TEST_SERVICE_URL, True)
70
+ service = TrussTestService(TEST_SERVICE_URL, True)
62
71
  assert service.is_ready()
63
72
 
64
73
 
65
74
  @mock.patch("requests.request", return_value=mock_unsuccessful_response())
66
75
  def test_is_not_ready(mock_request):
67
- service = TestTrussService(TEST_SERVICE_URL, True)
76
+ service = TrussTestService(TEST_SERVICE_URL, True)
68
77
  assert service.is_ready() is False
69
78
 
70
79
 
71
80
  @mock.patch("requests.request", return_value=mock_successful_response())
72
81
  def test_predict(mock_request):
73
- service = TestTrussService(TEST_SERVICE_URL, True)
82
+ service = TrussTestService(TEST_SERVICE_URL, True)
74
83
  response = service.predict({"model_input": "test"})
75
84
  assert response.status_code == 200
76
85
 
77
86
 
78
87
  @mock.patch("requests.request", return_value=mock_successful_response())
79
88
  def test_predict_no_data(mock_request):
80
- service = TestTrussService(TEST_SERVICE_URL, True)
89
+ service = TrussTestService(TEST_SERVICE_URL, True)
81
90
  with pytest.raises(ValueError):
82
91
  service._send_request(TEST_SERVICE_URL, "POST")
@@ -0,0 +1,11 @@
1
+ import os
2
+
3
+ from truss.templates.control.control.helpers.context_managers import current_directory
4
+
5
+
6
+ def test_current_directory(tmp_path):
7
+ orig_cwd = os.getcwd()
8
+ with current_directory(tmp_path):
9
+ assert os.getcwd() == str(tmp_path)
10
+
11
+ assert os.getcwd() == orig_cwd
@@ -0,0 +1,184 @@
1
+ import os
2
+ import sys
3
+ from pathlib import Path
4
+ from unittest import mock
5
+
6
+ import pytest
7
+ from truss.base.truss_config import TrussConfig
8
+
9
+ # Needed to simulate the set up on the model docker container
10
+ sys.path.append(
11
+ str(
12
+ Path(__file__).parent.parent.parent.parent.parent.parent
13
+ / "templates"
14
+ / "control"
15
+ / "control"
16
+ )
17
+ )
18
+
19
+ # Have to use imports in this form, otherwise isinstance checks fail on helper classes
20
+ from helpers.truss_patch.model_container_patch_applier import ( # noqa
21
+ ModelContainerPatchApplier,
22
+ )
23
+ from helpers.custom_types import ( # noqa
24
+ Action,
25
+ ConfigPatch,
26
+ EnvVarPatch,
27
+ ExternalDataPatch,
28
+ ModelCodePatch,
29
+ PackagePatch,
30
+ Patch,
31
+ PatchType,
32
+ )
33
+
34
+
35
+ @pytest.fixture
36
+ def patch_applier(truss_container_fs):
37
+ return ModelContainerPatchApplier(truss_container_fs / "app", mock.Mock())
38
+
39
+
40
+ def test_patch_applier_model_code_patch_add(
41
+ patch_applier: ModelContainerPatchApplier, truss_container_fs
42
+ ):
43
+ patch = Patch(
44
+ type=PatchType.MODEL_CODE,
45
+ body=ModelCodePatch(action=Action.ADD, path="dummy", content=""),
46
+ )
47
+ patch_applier(patch, os.environ.copy())
48
+ assert (truss_container_fs / "app" / "model" / "dummy").exists()
49
+
50
+
51
+ def test_patch_applier_model_code_patch_remove(
52
+ patch_applier: ModelContainerPatchApplier, truss_container_fs
53
+ ):
54
+ patch = Patch(
55
+ type=PatchType.MODEL_CODE,
56
+ body=ModelCodePatch(action=Action.REMOVE, path="model.py"),
57
+ )
58
+ assert (truss_container_fs / "app" / "model" / "model.py").exists()
59
+ patch_applier(patch, os.environ.copy())
60
+ assert not (truss_container_fs / "app" / "model" / "model.py").exists()
61
+
62
+
63
+ def test_patch_applier_model_code_patch_update(
64
+ patch_applier: ModelContainerPatchApplier, truss_container_fs
65
+ ):
66
+ new_model_file_content = """
67
+ class Model:
68
+ pass
69
+ """
70
+ patch = Patch(
71
+ type=PatchType.MODEL_CODE,
72
+ body=ModelCodePatch(
73
+ action=Action.UPDATE, path="model.py", content=new_model_file_content
74
+ ),
75
+ )
76
+ patch_applier(patch, os.environ.copy())
77
+ assert (
78
+ truss_container_fs / "app" / "model" / "model.py"
79
+ ).read_text() == new_model_file_content
80
+
81
+
82
+ def test_patch_applier_package_patch_add(
83
+ patch_applier: ModelContainerPatchApplier, truss_container_fs
84
+ ):
85
+ patch = Patch(
86
+ type=PatchType.PACKAGE,
87
+ body=PackagePatch(
88
+ action=Action.ADD, path="test_package/test.py", content="foobar"
89
+ ),
90
+ )
91
+ patch_applier(patch, os.environ.copy())
92
+ assert (truss_container_fs / "packages" / "test_package" / "test.py").exists()
93
+
94
+
95
+ def test_patch_applier_package_patch_remove(
96
+ patch_applier: ModelContainerPatchApplier, truss_container_fs
97
+ ):
98
+ patch = Patch(
99
+ type=PatchType.PACKAGE,
100
+ body=PackagePatch(action=Action.REMOVE, path="test_package/test.py"),
101
+ )
102
+ assert (truss_container_fs / "packages" / "test_package" / "test.py").exists()
103
+ patch_applier(patch, os.environ.copy())
104
+ assert not (truss_container_fs / "packages" / "test_package" / "test.py").exists()
105
+
106
+
107
+ def test_patch_applier_package_patch_update(
108
+ patch_applier: ModelContainerPatchApplier, truss_container_fs
109
+ ):
110
+ new_package_content = """X = 2"""
111
+ patch = Patch(
112
+ type=PatchType.PACKAGE,
113
+ body=PackagePatch(
114
+ action=Action.UPDATE,
115
+ path="test_package/test.py",
116
+ content=new_package_content,
117
+ ),
118
+ )
119
+ patch_applier(patch, os.environ.copy())
120
+ assert (
121
+ truss_container_fs / "packages" / "test_package" / "test.py"
122
+ ).read_text() == new_package_content
123
+
124
+
125
+ def test_patch_applier_config_patch_update(
126
+ patch_applier: ModelContainerPatchApplier, truss_container_fs
127
+ ):
128
+ new_config_dict = {"model_name": "foobar"}
129
+ patch = Patch(
130
+ type=PatchType.CONFIG,
131
+ body=ConfigPatch(action=Action.UPDATE, config=new_config_dict),
132
+ )
133
+ patch_applier(patch, os.environ.copy())
134
+ new_config = TrussConfig.from_yaml(truss_container_fs / "app" / "config.yaml")
135
+ assert new_config.model_name == "foobar"
136
+
137
+
138
+ def test_patch_applier_env_var_patch_update(patch_applier: ModelContainerPatchApplier):
139
+ env_var_dict = {"FOO": "BAR"}
140
+ patch = Patch(
141
+ type=PatchType.ENVIRONMENT_VARIABLE,
142
+ body=EnvVarPatch(action=Action.UPDATE, item={"FOO": "BAR-PATCHED"}),
143
+ )
144
+ patch_applier(patch, env_var_dict)
145
+ assert env_var_dict["FOO"] == "BAR-PATCHED"
146
+
147
+
148
+ def test_patch_applier_env_var_patch_add(patch_applier: ModelContainerPatchApplier):
149
+ env_var_dict = {"FOO": "BAR"}
150
+ patch = Patch(
151
+ type=PatchType.ENVIRONMENT_VARIABLE,
152
+ body=EnvVarPatch(action=Action.ADD, item={"BAR": "FOO"}),
153
+ )
154
+ patch_applier(patch, env_var_dict)
155
+ assert env_var_dict["FOO"] == "BAR"
156
+ assert env_var_dict["BAR"] == "FOO"
157
+
158
+
159
+ def test_patch_applier_env_var_patch_remove(patch_applier: ModelContainerPatchApplier):
160
+ env_var_dict = {"FOO": "BAR"}
161
+ patch = Patch(
162
+ type=PatchType.ENVIRONMENT_VARIABLE,
163
+ body=EnvVarPatch(action=Action.REMOVE, item={"FOO": "BAR"}),
164
+ )
165
+ patch_applier(patch, env_var_dict)
166
+ with pytest.raises(KeyError):
167
+ _ = env_var_dict["FOO"]
168
+
169
+
170
+ def test_patch_applier_external_data_patch_add(
171
+ patch_applier: ModelContainerPatchApplier, truss_container_fs
172
+ ):
173
+ patch = Patch(
174
+ type=PatchType.EXTERNAL_DATA,
175
+ body=ExternalDataPatch(
176
+ action=Action.ADD,
177
+ item={
178
+ "url": "https://raw.githubusercontent.com/basetenlabs/truss/main/docs/favicon.svg",
179
+ "local_data_path": "truss_icon",
180
+ },
181
+ ),
182
+ )
183
+ patch_applier(patch, os.environ.copy())
184
+ assert (truss_container_fs / "app" / "data" / "truss_icon").exists()
@@ -0,0 +1,89 @@
1
+ import pytest
2
+ from truss.templates.control.control.helpers.truss_patch.requirement_name_identifier import (
3
+ RequirementMeta,
4
+ identify_requirement_name,
5
+ reqs_by_name,
6
+ )
7
+
8
+
9
+ @pytest.mark.parametrize(
10
+ "req, expected_name",
11
+ [
12
+ ("pytorch", "pytorch"),
13
+ (
14
+ "git+https://github.com/huggingface/transformers.git#egg=transformers",
15
+ "git+github.com/huggingface/transformers.git",
16
+ ),
17
+ (
18
+ "git+https://github.com/huggingface/transformers.git",
19
+ "git+github.com/huggingface/transformers.git",
20
+ ),
21
+ (
22
+ "git+https://github.com/huggingface/transformers.git@main#egg=transformers",
23
+ "git+github.com/huggingface/transformers.git",
24
+ ),
25
+ (
26
+ " git+https://github.com/huggingface/transformers.git ",
27
+ "git+github.com/huggingface/transformers.git",
28
+ ),
29
+ ("pytorch==1.0", "pytorch"),
30
+ ("pytorch>=1.0", "pytorch"),
31
+ ("pytorch<=1.0", "pytorch"),
32
+ ],
33
+ )
34
+ def test_identify_requirement_name(req, expected_name):
35
+ assert expected_name == identify_requirement_name(req)
36
+
37
+
38
+ def test_reqs_by_name():
39
+ reqs = ["pytorch", " ", "jinja==1.0"]
40
+ assert reqs_by_name(reqs) == {"pytorch": "pytorch", "jinja": "jinja==1.0"}
41
+
42
+
43
+ @pytest.mark.parametrize(
44
+ "desc, req, expected_meta",
45
+ [
46
+ (
47
+ "handles simple requirement",
48
+ "pytorch",
49
+ RequirementMeta(
50
+ requirement="pytorch",
51
+ name="pytorch",
52
+ is_url_based_requirement=False,
53
+ egg_tag=None,
54
+ ),
55
+ ),
56
+ (
57
+ "handles python package with version",
58
+ "pytorch==1.0",
59
+ RequirementMeta(
60
+ requirement="pytorch==1.0",
61
+ name="pytorch",
62
+ is_url_based_requirement=False,
63
+ egg_tag=None,
64
+ ),
65
+ ),
66
+ (
67
+ "handles url-based requirement with egg tag",
68
+ "git+https://github.com/huggingface/transformers.git@main#egg=transformers",
69
+ RequirementMeta(
70
+ requirement="git+https://github.com/huggingface/transformers.git@main#egg=transformers",
71
+ name="git+github.com/huggingface/transformers.git",
72
+ is_url_based_requirement=True,
73
+ egg_tag=["transformers"],
74
+ ),
75
+ ),
76
+ (
77
+ "handles url-based requirement without egg tag",
78
+ "git+https://github.com/huggingface/transformers.git",
79
+ RequirementMeta(
80
+ requirement="git+https://github.com/huggingface/transformers.git",
81
+ name="git+github.com/huggingface/transformers.git",
82
+ is_url_based_requirement=True,
83
+ egg_tag=None,
84
+ ),
85
+ ),
86
+ ],
87
+ )
88
+ def test_requirement_meta_from_req(desc, req: str, expected_meta: RequirementMeta):
89
+ assert expected_meta == RequirementMeta.from_req(req), desc