llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (311)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  71. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  72. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  73. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  74. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  75. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  76. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  77. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  78. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  79. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  80. llama_stack/providers/registry/agents.py +1 -0
  81. llama_stack/providers/registry/inference.py +1 -9
  82. llama_stack/providers/registry/vector_io.py +136 -16
  83. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  84. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  85. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  86. llama_stack/providers/remote/files/s3/README.md +266 -0
  87. llama_stack/providers/remote/files/s3/config.py +5 -3
  88. llama_stack/providers/remote/files/s3/files.py +2 -2
  89. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  90. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  91. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  92. llama_stack/providers/remote/inference/together/together.py +4 -0
  93. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  94. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  95. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  96. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  97. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  98. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  99. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  100. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  101. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  102. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  103. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  104. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  105. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  106. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  107. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  108. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  109. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  110. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  111. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  112. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  113. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  114. llama_stack/providers/utils/bedrock/client.py +3 -3
  115. llama_stack/providers/utils/bedrock/config.py +7 -7
  116. llama_stack/providers/utils/inference/__init__.py +0 -25
  117. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  118. llama_stack/providers/utils/inference/http_client.py +239 -0
  119. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  120. llama_stack/providers/utils/inference/model_registry.py +148 -2
  121. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  122. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  123. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  124. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  125. llama_stack/providers/utils/memory/vector_store.py +46 -19
  126. llama_stack/providers/utils/responses/responses_store.py +40 -6
  127. llama_stack/providers/utils/safety.py +114 -0
  128. llama_stack/providers/utils/tools/mcp.py +44 -3
  129. llama_stack/testing/api_recorder.py +9 -3
  130. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  131. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
  132. llama_stack-0.5.0.dist-info/top_level.txt +1 -0
  133. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  134. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  135. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  136. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  137. llama_stack/models/llama/hadamard_utils.py +0 -88
  138. llama_stack/models/llama/llama3/args.py +0 -74
  139. llama_stack/models/llama/llama3/generation.py +0 -378
  140. llama_stack/models/llama/llama3/model.py +0 -304
  141. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  142. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  143. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  144. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  145. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  146. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  147. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  148. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  149. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  150. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  151. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  152. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  153. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  154. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  155. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  156. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  157. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  158. llama_stack/models/llama/llama4/args.py +0 -107
  159. llama_stack/models/llama/llama4/ffn.py +0 -58
  160. llama_stack/models/llama/llama4/moe.py +0 -214
  161. llama_stack/models/llama/llama4/preprocess.py +0 -435
  162. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  163. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  164. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  165. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  166. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  167. llama_stack/models/llama/quantize_impls.py +0 -316
  168. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  169. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  170. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  171. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  172. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  173. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  174. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  175. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  176. llama_stack_api/__init__.py +0 -945
  177. llama_stack_api/admin/__init__.py +0 -45
  178. llama_stack_api/admin/api.py +0 -72
  179. llama_stack_api/admin/fastapi_routes.py +0 -117
  180. llama_stack_api/admin/models.py +0 -113
  181. llama_stack_api/agents.py +0 -173
  182. llama_stack_api/batches/__init__.py +0 -40
  183. llama_stack_api/batches/api.py +0 -53
  184. llama_stack_api/batches/fastapi_routes.py +0 -113
  185. llama_stack_api/batches/models.py +0 -78
  186. llama_stack_api/benchmarks/__init__.py +0 -43
  187. llama_stack_api/benchmarks/api.py +0 -39
  188. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  189. llama_stack_api/benchmarks/models.py +0 -109
  190. llama_stack_api/common/__init__.py +0 -5
  191. llama_stack_api/common/content_types.py +0 -101
  192. llama_stack_api/common/errors.py +0 -95
  193. llama_stack_api/common/job_types.py +0 -38
  194. llama_stack_api/common/responses.py +0 -77
  195. llama_stack_api/common/training_types.py +0 -47
  196. llama_stack_api/common/type_system.py +0 -146
  197. llama_stack_api/connectors.py +0 -146
  198. llama_stack_api/conversations.py +0 -270
  199. llama_stack_api/datasetio.py +0 -55
  200. llama_stack_api/datasets/__init__.py +0 -61
  201. llama_stack_api/datasets/api.py +0 -35
  202. llama_stack_api/datasets/fastapi_routes.py +0 -104
  203. llama_stack_api/datasets/models.py +0 -152
  204. llama_stack_api/datatypes.py +0 -373
  205. llama_stack_api/eval.py +0 -137
  206. llama_stack_api/file_processors/__init__.py +0 -27
  207. llama_stack_api/file_processors/api.py +0 -64
  208. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  209. llama_stack_api/file_processors/models.py +0 -42
  210. llama_stack_api/files/__init__.py +0 -35
  211. llama_stack_api/files/api.py +0 -51
  212. llama_stack_api/files/fastapi_routes.py +0 -124
  213. llama_stack_api/files/models.py +0 -107
  214. llama_stack_api/inference.py +0 -1169
  215. llama_stack_api/inspect_api/__init__.py +0 -37
  216. llama_stack_api/inspect_api/api.py +0 -25
  217. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  218. llama_stack_api/inspect_api/models.py +0 -28
  219. llama_stack_api/internal/kvstore.py +0 -28
  220. llama_stack_api/internal/sqlstore.py +0 -81
  221. llama_stack_api/llama_stack_api/__init__.py +0 -945
  222. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  223. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  224. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  225. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  226. llama_stack_api/llama_stack_api/agents.py +0 -173
  227. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  228. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  229. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  230. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  231. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  232. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  233. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  234. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  235. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  236. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  237. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  238. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  239. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  240. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  241. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  242. llama_stack_api/llama_stack_api/connectors.py +0 -146
  243. llama_stack_api/llama_stack_api/conversations.py +0 -270
  244. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  245. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  246. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  247. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  248. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  249. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  250. llama_stack_api/llama_stack_api/eval.py +0 -137
  251. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  252. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  253. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  254. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  255. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  256. llama_stack_api/llama_stack_api/files/api.py +0 -51
  257. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  258. llama_stack_api/llama_stack_api/files/models.py +0 -107
  259. llama_stack_api/llama_stack_api/inference.py +0 -1169
  260. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  261. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  262. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  263. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  264. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  265. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  266. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  267. llama_stack_api/llama_stack_api/models.py +0 -171
  268. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  269. llama_stack_api/llama_stack_api/post_training.py +0 -370
  270. llama_stack_api/llama_stack_api/prompts.py +0 -203
  271. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  272. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  273. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  274. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  275. llama_stack_api/llama_stack_api/py.typed +0 -0
  276. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  277. llama_stack_api/llama_stack_api/resource.py +0 -37
  278. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  279. llama_stack_api/llama_stack_api/safety.py +0 -132
  280. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  281. llama_stack_api/llama_stack_api/scoring.py +0 -93
  282. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  283. llama_stack_api/llama_stack_api/shields.py +0 -93
  284. llama_stack_api/llama_stack_api/tools.py +0 -226
  285. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  286. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  287. llama_stack_api/llama_stack_api/version.py +0 -9
  288. llama_stack_api/models.py +0 -171
  289. llama_stack_api/openai_responses.py +0 -1468
  290. llama_stack_api/post_training.py +0 -370
  291. llama_stack_api/prompts.py +0 -203
  292. llama_stack_api/providers/__init__.py +0 -33
  293. llama_stack_api/providers/api.py +0 -16
  294. llama_stack_api/providers/fastapi_routes.py +0 -57
  295. llama_stack_api/providers/models.py +0 -24
  296. llama_stack_api/py.typed +0 -0
  297. llama_stack_api/rag_tool.py +0 -168
  298. llama_stack_api/resource.py +0 -37
  299. llama_stack_api/router_utils.py +0 -160
  300. llama_stack_api/safety.py +0 -132
  301. llama_stack_api/schema_utils.py +0 -208
  302. llama_stack_api/scoring.py +0 -93
  303. llama_stack_api/scoring_functions.py +0 -211
  304. llama_stack_api/shields.py +0 -93
  305. llama_stack_api/tools.py +0 -226
  306. llama_stack_api/vector_io.py +0 -941
  307. llama_stack_api/vector_stores.py +0 -53
  308. llama_stack_api/version.py +0 -9
  309. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  310. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  311. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
llama_stack/providers/inline/inference/meta_reference/parallel_utils.py (removed)
@@ -1,353 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- import copy
- import json
- import multiprocessing
- import os
- import tempfile
- import time
- import uuid
- from collections.abc import Callable, Generator
- from enum import Enum
- from typing import Annotated, Literal
-
- import torch
- import zmq
- from fairscale.nn.model_parallel.initialize import (
-     get_model_parallel_group,
-     get_model_parallel_rank,
-     get_model_parallel_src_rank,
- )
- from pydantic import BaseModel, Field
- from torch.distributed.launcher.api import LaunchConfig, elastic_launch
-
- from llama_stack.log import get_logger
- from llama_stack.models.llama.datatypes import GenerationResult
-
- log = get_logger(name=__name__, category="inference")
-
-
- class ProcessingMessageName(str, Enum):
-     ready_request = "ready_request"
-     ready_response = "ready_response"
-     end_sentinel = "end_sentinel"
-     cancel_sentinel = "cancel_sentinel"
-     task_request = "task_request"
-     task_response = "task_response"
-     exception_response = "exception_response"
-
-
- class ReadyRequest(BaseModel):
-     type: Literal[ProcessingMessageName.ready_request] = ProcessingMessageName.ready_request
-
-
- class ReadyResponse(BaseModel):
-     type: Literal[ProcessingMessageName.ready_response] = ProcessingMessageName.ready_response
-
-
- class EndSentinel(BaseModel):
-     type: Literal[ProcessingMessageName.end_sentinel] = ProcessingMessageName.end_sentinel
-
-
- class CancelSentinel(BaseModel):
-     type: Literal[ProcessingMessageName.cancel_sentinel] = ProcessingMessageName.cancel_sentinel
-
-
- class TaskRequest(BaseModel):
-     type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-     task: tuple[str, list]
-
-
- class TaskResponse(BaseModel):
-     type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
-     result: list[GenerationResult]
-
-
- class ExceptionResponse(BaseModel):
-     type: Literal[ProcessingMessageName.exception_response] = ProcessingMessageName.exception_response
-     error: str
-
-
- ProcessingMessage = (
-     ReadyRequest | ReadyResponse | EndSentinel | CancelSentinel | TaskRequest | TaskResponse | ExceptionResponse
- )
-
-
- class ProcessingMessageWrapper(BaseModel):
-     payload: Annotated[
-         ProcessingMessage,
-         Field(discriminator="type"),
-     ]
-
-
- def mp_rank_0() -> bool:
-     return bool(get_model_parallel_rank() == 0)
-
-
- def encode_msg(msg: ProcessingMessage) -> bytes:
-     return ProcessingMessageWrapper(payload=msg).model_dump_json().encode("utf-8")
-
-
- def retrieve_requests(reply_socket_url: str):
-     if mp_rank_0():
-         context = zmq.Context()
-         reply_socket = context.socket(zmq.ROUTER)
-         reply_socket.connect(reply_socket_url)
-
-         while True:
-             client_id, obj = maybe_get_work(reply_socket)
-             if obj is None:
-                 time.sleep(0.01)
-                 continue
-
-             ready_response = ReadyResponse()
-             reply_socket.send_multipart([client_id, encode_msg(ready_response)])
-             break
-
-     def send_obj(obj: ProcessingMessage):
-         reply_socket.send_multipart([client_id, encode_msg(obj)])
-
-     while True:
-         tasks: list[ProcessingMessage | None] = [None]
-         if mp_rank_0():
-             client_id, maybe_task_json = maybe_get_work(reply_socket)
-             if maybe_task_json is not None:
-                 task = maybe_parse_message(maybe_task_json)
-                 # there is still an unknown unclean GeneratorExit happening resulting in a
-                 # cancel sentinel getting queued _after_ we have finished sending everything :/
-                 # kind of a hack this is :/
-                 if task is not None and not isinstance(task, CancelSentinel):
-                     tasks = [task]
-
-         torch.distributed.broadcast_object_list(
-             tasks,
-             src=get_model_parallel_src_rank(),
-             group=get_model_parallel_group(),
-         )
-
-         task = tasks[0]
-         if task is None:
-             time.sleep(0.1)
-         else:
-             try:
-                 out = yield task
-                 if out is None:
-                     break
-
-                 for obj in out:
-                     updates: list[ProcessingMessage | None] = [None]
-                     if mp_rank_0():
-                         _, update_json = maybe_get_work(reply_socket)
-                         update = maybe_parse_message(update_json)
-                         if isinstance(update, CancelSentinel):
-                             updates = [update]
-                         else:
-                             # only send the update if it's not cancelled otherwise the object sits in the socket
-                             # and gets pulled in the next request lol
-                             send_obj(TaskResponse(result=obj))
-
-                     torch.distributed.broadcast_object_list(
-                         updates,
-                         src=get_model_parallel_src_rank(),
-                         group=get_model_parallel_group(),
-                     )
-                     if isinstance(updates[0], CancelSentinel):
-                         log.info("quitting generation loop because request was cancelled")
-                         break
-
-                 if mp_rank_0():
-                     send_obj(EndSentinel())
-             except Exception as e:
-                 log.exception("exception in generation loop")
-
-                 if mp_rank_0():
-                     send_obj(ExceptionResponse(error=str(e)))
-
-                 if mp_rank_0():
-                     send_obj(EndSentinel())
-
-
- def maybe_get_work(sock: zmq.Socket):
-     message = None
-     client_id = None
-     try:
-         client_id, obj = sock.recv_multipart(zmq.NOBLOCK)
-         message = obj.decode("utf-8")
-     except zmq.ZMQError as e:
-         if e.errno != zmq.EAGAIN:
-             raise e
-
-     return client_id, message
-
-
- def maybe_parse_message(maybe_json: str | None) -> ProcessingMessage | None:
-     if maybe_json is None:
-         return None
-     try:
-         return parse_message(maybe_json)
-     except json.JSONDecodeError:
-         return None
-     except ValueError:
-         return None
-
-
- def parse_message(json_str: str) -> ProcessingMessage:
-     data = json.loads(json_str)
-     return copy.deepcopy(ProcessingMessageWrapper(**data).payload)
-
-
- def worker_process_entrypoint(
-     reply_socket_url: str,
-     init_model_cb: Callable,
- ) -> None:
-     model = init_model_cb()
-     torch.distributed.barrier()
-     time.sleep(1)
-
-     # run the requests co-routine which retrieves requests from the socket
-     # and sends responses (we provide) back to the caller
-     req_gen = retrieve_requests(reply_socket_url)
-     result = None
-     while True:
-         try:
-             task = req_gen.send(result)
-             if isinstance(task, EndSentinel):
-                 break
-
-             assert isinstance(task, TaskRequest), task
-             result = model(task.task)
-         except StopIteration:
-             break
-
-     log.info("[debug] worker process done")
-
-
- def launch_dist_group(
-     reply_socket_url: str,
-     model_parallel_size: int,
-     init_model_cb: Callable,
-     **kwargs,
- ) -> None:
-     with tempfile.TemporaryDirectory() as tmpdir:
-         # TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
-         launch_config = LaunchConfig(
-             max_nodes=1,
-             min_nodes=1,
-             nproc_per_node=model_parallel_size,
-             start_method="fork",
-             rdzv_backend="c10d",
-             rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
-             rdzv_configs={"store_type": "file", "timeout": 90},
-             max_restarts=0,
-             monitor_interval=1,
-             run_id=str(uuid.uuid4()),
-         )
-         elastic_launch(launch_config, entrypoint=worker_process_entrypoint)(
-             reply_socket_url,
-             init_model_cb,
-         )
-
-
- def start_model_parallel_process(
-     model_parallel_size: int,
-     init_model_cb: Callable,
-     **kwargs,
- ):
-     context = zmq.Context()
-     request_socket = context.socket(zmq.DEALER)
-
-     # Binding the request socket to a random port
-     request_socket.bind("tcp://127.0.0.1:0")
-
-     main_process_url = request_socket.getsockopt_string(zmq.LAST_ENDPOINT)
-
-     ctx = multiprocessing.get_context("spawn")
-     process = ctx.Process(
-         target=launch_dist_group,
-         args=(
-             main_process_url,
-             model_parallel_size,
-             init_model_cb,
-         ),
-         kwargs=kwargs,
-     )
-     process.start()
-
-     # wait until the model is loaded; rank 0 will send a message to indicate it's ready
-
-     request_socket.send(encode_msg(ReadyRequest()))
-     _response = request_socket.recv()
-     log.info("Loaded model...")
-
-     return request_socket, process
-
-
- class ModelParallelProcessGroup:
-     def __init__(
-         self,
-         model_parallel_size: int,
-         init_model_cb: Callable,
-         **kwargs,
-     ):
-         self.model_parallel_size = model_parallel_size
-         self.init_model_cb = init_model_cb
-         self.started = False
-         self.running = False
-
-     def start(self):
-         assert not self.started, "process group already started"
-         self.request_socket, self.process = start_model_parallel_process(
-             self.model_parallel_size,
-             self.init_model_cb,
-         )
-         self.started = True
-
-     def stop(self):
-         assert self.started, "process group not started"
-         if self.process.is_alive():
-             self.request_socket.send(encode_msg(EndSentinel()), zmq.NOBLOCK)
-         self.process.join()
-         self.started = False
-
-     def run_inference(
-         self,
-         req: tuple[str, list],
-     ) -> Generator:
-         assert not self.running, "inference already running"
-
-         self.running = True
-         try:
-             self.request_socket.send(encode_msg(TaskRequest(task=req)))
-             while True:
-                 obj_json = self.request_socket.recv()
-                 obj = parse_message(obj_json)
-
-                 if isinstance(obj, EndSentinel):
-                     break
-
-                 if isinstance(obj, ExceptionResponse):
-                     log.error(f"[debug] got exception {obj.error}")
-                     raise Exception(obj.error)
-
-                 if isinstance(obj, TaskResponse):
-                     yield obj.result
-
-         except GeneratorExit:
-             self.request_socket.send(encode_msg(CancelSentinel()))
-             while True:
-                 obj_json = self.request_socket.recv()
-                 obj = parse_message(obj_json)
-                 if isinstance(obj, EndSentinel):
-                     break
-         finally:
-             self.running = False
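For reference, the removed ModelParallelProcessGroup was driven roughly as follows. This is a minimal sketch inferred only from the deleted code above; it applies to 0.4.3 and earlier (the module no longer exists in 0.5.0), and the init_model_cb body and the ("chat_completion", [...]) request payload are illustrative assumptions, not an API documented by this diff.

from llama_stack.providers.inline.inference.meta_reference.parallel_utils import (
    ModelParallelProcessGroup,
)

def init_model_cb():
    # Assumption: return a callable that takes the (task_id, args) tuple and
    # yields batches of list[GenerationResult]; each model-parallel rank calls this once.
    ...

group = ModelParallelProcessGroup(model_parallel_size=2, init_model_cb=init_model_cb)
group.start()  # spawns the worker group and blocks on the ReadyRequest/ReadyResponse handshake
try:
    # run_inference() streams list[GenerationResult] batches until an EndSentinel arrives;
    # closing the generator early sends a CancelSentinel to the workers.
    for results in group.run_inference(("chat_completion", [{"messages": []}])):
        for generation in results:
            print(generation)
finally:
    group.stop()  # sends an EndSentinel and joins the worker process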
llama_stack-0.4.3.dist-info/top_level.txt (removed)
@@ -1,2 +0,0 @@
- llama_stack
- llama_stack_api