llama-stack 0.3.5__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (458)
  1. llama_stack/__init__.py +0 -5
  2. llama_stack/cli/llama.py +3 -3
  3. llama_stack/cli/stack/_list_deps.py +12 -23
  4. llama_stack/cli/stack/list_stacks.py +37 -18
  5. llama_stack/cli/stack/run.py +121 -11
  6. llama_stack/cli/stack/utils.py +0 -127
  7. llama_stack/core/access_control/access_control.py +69 -28
  8. llama_stack/core/access_control/conditions.py +15 -5
  9. llama_stack/core/admin.py +267 -0
  10. llama_stack/core/build.py +6 -74
  11. llama_stack/core/client.py +1 -1
  12. llama_stack/core/configure.py +6 -6
  13. llama_stack/core/conversations/conversations.py +28 -25
  14. llama_stack/core/datatypes.py +271 -79
  15. llama_stack/core/distribution.py +15 -16
  16. llama_stack/core/external.py +3 -3
  17. llama_stack/core/inspect.py +98 -15
  18. llama_stack/core/library_client.py +73 -61
  19. llama_stack/core/prompts/prompts.py +12 -11
  20. llama_stack/core/providers.py +17 -11
  21. llama_stack/core/resolver.py +65 -56
  22. llama_stack/core/routers/__init__.py +8 -12
  23. llama_stack/core/routers/datasets.py +1 -4
  24. llama_stack/core/routers/eval_scoring.py +7 -4
  25. llama_stack/core/routers/inference.py +55 -271
  26. llama_stack/core/routers/safety.py +52 -24
  27. llama_stack/core/routers/tool_runtime.py +6 -48
  28. llama_stack/core/routers/vector_io.py +130 -51
  29. llama_stack/core/routing_tables/benchmarks.py +24 -20
  30. llama_stack/core/routing_tables/common.py +1 -4
  31. llama_stack/core/routing_tables/datasets.py +22 -22
  32. llama_stack/core/routing_tables/models.py +119 -6
  33. llama_stack/core/routing_tables/scoring_functions.py +7 -7
  34. llama_stack/core/routing_tables/shields.py +1 -2
  35. llama_stack/core/routing_tables/toolgroups.py +17 -7
  36. llama_stack/core/routing_tables/vector_stores.py +51 -16
  37. llama_stack/core/server/auth.py +5 -3
  38. llama_stack/core/server/auth_providers.py +36 -20
  39. llama_stack/core/server/fastapi_router_registry.py +84 -0
  40. llama_stack/core/server/quota.py +2 -2
  41. llama_stack/core/server/routes.py +79 -27
  42. llama_stack/core/server/server.py +102 -87
  43. llama_stack/core/stack.py +201 -58
  44. llama_stack/core/storage/datatypes.py +26 -3
  45. llama_stack/{providers/utils → core/storage}/kvstore/__init__.py +2 -0
  46. llama_stack/{providers/utils → core/storage}/kvstore/kvstore.py +55 -24
  47. llama_stack/{providers/utils → core/storage}/kvstore/mongodb/mongodb.py +13 -10
  48. llama_stack/{providers/utils → core/storage}/kvstore/postgres/postgres.py +28 -17
  49. llama_stack/{providers/utils → core/storage}/kvstore/redis/redis.py +41 -16
  50. llama_stack/{providers/utils → core/storage}/kvstore/sqlite/sqlite.py +1 -1
  51. llama_stack/core/storage/sqlstore/__init__.py +17 -0
  52. llama_stack/{providers/utils → core/storage}/sqlstore/authorized_sqlstore.py +69 -49
  53. llama_stack/{providers/utils → core/storage}/sqlstore/sqlalchemy_sqlstore.py +47 -17
  54. llama_stack/{providers/utils → core/storage}/sqlstore/sqlstore.py +25 -8
  55. llama_stack/core/store/registry.py +1 -1
  56. llama_stack/core/utils/config.py +8 -2
  57. llama_stack/core/utils/config_resolution.py +32 -29
  58. llama_stack/core/utils/context.py +4 -10
  59. llama_stack/core/utils/exec.py +9 -0
  60. llama_stack/core/utils/type_inspection.py +45 -0
  61. llama_stack/distributions/dell/{run.yaml → config.yaml} +3 -2
  62. llama_stack/distributions/dell/dell.py +2 -2
  63. llama_stack/distributions/dell/run-with-safety.yaml +3 -2
  64. llama_stack/distributions/meta-reference-gpu/{run.yaml → config.yaml} +3 -2
  65. llama_stack/distributions/meta-reference-gpu/meta_reference.py +2 -2
  66. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +3 -2
  67. llama_stack/distributions/nvidia/{run.yaml → config.yaml} +4 -4
  68. llama_stack/distributions/nvidia/nvidia.py +1 -1
  69. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -4
  70. llama_stack/{apis/datasetio → distributions/oci}/__init__.py +1 -1
  71. llama_stack/distributions/oci/config.yaml +134 -0
  72. llama_stack/distributions/oci/oci.py +108 -0
  73. llama_stack/distributions/open-benchmark/{run.yaml → config.yaml} +5 -4
  74. llama_stack/distributions/open-benchmark/open_benchmark.py +2 -3
  75. llama_stack/distributions/postgres-demo/{run.yaml → config.yaml} +4 -3
  76. llama_stack/distributions/starter/{run.yaml → config.yaml} +64 -13
  77. llama_stack/distributions/starter/run-with-postgres-store.yaml +64 -13
  78. llama_stack/distributions/starter/starter.py +8 -5
  79. llama_stack/distributions/starter-gpu/{run.yaml → config.yaml} +64 -13
  80. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +64 -13
  81. llama_stack/distributions/template.py +13 -69
  82. llama_stack/distributions/watsonx/{run.yaml → config.yaml} +4 -3
  83. llama_stack/distributions/watsonx/watsonx.py +1 -1
  84. llama_stack/log.py +28 -11
  85. llama_stack/models/llama/checkpoint.py +6 -6
  86. llama_stack/models/llama/hadamard_utils.py +2 -0
  87. llama_stack/models/llama/llama3/generation.py +3 -1
  88. llama_stack/models/llama/llama3/interface.py +2 -5
  89. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +3 -3
  90. llama_stack/models/llama/llama3/multimodal/image_transform.py +6 -6
  91. llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +1 -1
  92. llama_stack/models/llama/llama3/tool_utils.py +2 -1
  93. llama_stack/models/llama/llama4/prompt_templates/system_prompts.py +1 -1
  94. llama_stack/providers/inline/agents/meta_reference/__init__.py +3 -3
  95. llama_stack/providers/inline/agents/meta_reference/agents.py +44 -261
  96. llama_stack/providers/inline/agents/meta_reference/config.py +6 -1
  97. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +207 -57
  98. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +308 -47
  99. llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +162 -96
  100. llama_stack/providers/inline/agents/meta_reference/responses/types.py +23 -8
  101. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +201 -33
  102. llama_stack/providers/inline/agents/meta_reference/safety.py +8 -13
  103. llama_stack/providers/inline/batches/reference/__init__.py +2 -4
  104. llama_stack/providers/inline/batches/reference/batches.py +78 -60
  105. llama_stack/providers/inline/datasetio/localfs/datasetio.py +2 -5
  106. llama_stack/providers/inline/eval/meta_reference/eval.py +16 -61
  107. llama_stack/providers/inline/files/localfs/files.py +37 -28
  108. llama_stack/providers/inline/inference/meta_reference/config.py +2 -2
  109. llama_stack/providers/inline/inference/meta_reference/generators.py +50 -60
  110. llama_stack/providers/inline/inference/meta_reference/inference.py +403 -19
  111. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +7 -26
  112. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +2 -12
  113. llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +10 -15
  114. llama_stack/providers/inline/post_training/common/validator.py +1 -5
  115. llama_stack/providers/inline/post_training/huggingface/post_training.py +8 -8
  116. llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +18 -10
  117. llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py +12 -9
  118. llama_stack/providers/inline/post_training/huggingface/utils.py +27 -6
  119. llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py +1 -1
  120. llama_stack/providers/inline/post_training/torchtune/common/utils.py +1 -1
  121. llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py +1 -1
  122. llama_stack/providers/inline/post_training/torchtune/post_training.py +8 -8
  123. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +16 -16
  124. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +13 -9
  125. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +18 -15
  126. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +9 -9
  127. llama_stack/providers/inline/scoring/basic/scoring.py +6 -13
  128. llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py +1 -2
  129. llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +1 -2
  130. llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py +2 -2
  131. llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py +2 -2
  132. llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/ifeval.py +2 -2
  133. llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py +2 -2
  134. llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py +2 -2
  135. llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py +2 -2
  136. llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py +1 -2
  137. llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py +1 -2
  138. llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +1 -2
  139. llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py +1 -2
  140. llama_stack/providers/inline/scoring/braintrust/braintrust.py +12 -15
  141. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py +2 -2
  142. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py +2 -2
  143. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py +2 -2
  144. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py +2 -2
  145. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py +2 -2
  146. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py +2 -2
  147. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py +2 -2
  148. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py +2 -2
  149. llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py +2 -2
  150. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +7 -14
  151. llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_simpleqa.py +2 -2
  152. llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py +1 -2
  153. llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +1 -3
  154. llama_stack/providers/inline/tool_runtime/rag/__init__.py +1 -1
  155. llama_stack/providers/inline/tool_runtime/rag/config.py +8 -1
  156. llama_stack/providers/inline/tool_runtime/rag/context_retriever.py +7 -6
  157. llama_stack/providers/inline/tool_runtime/rag/memory.py +64 -48
  158. llama_stack/providers/inline/vector_io/chroma/__init__.py +1 -1
  159. llama_stack/providers/inline/vector_io/chroma/config.py +1 -1
  160. llama_stack/providers/inline/vector_io/faiss/__init__.py +1 -1
  161. llama_stack/providers/inline/vector_io/faiss/config.py +1 -1
  162. llama_stack/providers/inline/vector_io/faiss/faiss.py +43 -28
  163. llama_stack/providers/inline/vector_io/milvus/__init__.py +1 -1
  164. llama_stack/providers/inline/vector_io/milvus/config.py +1 -1
  165. llama_stack/providers/inline/vector_io/qdrant/__init__.py +1 -1
  166. llama_stack/providers/inline/vector_io/qdrant/config.py +1 -1
  167. llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py +1 -1
  168. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +40 -33
  169. llama_stack/providers/registry/agents.py +7 -3
  170. llama_stack/providers/registry/batches.py +1 -1
  171. llama_stack/providers/registry/datasetio.py +1 -1
  172. llama_stack/providers/registry/eval.py +1 -1
  173. llama_stack/{apis/datasets/__init__.py → providers/registry/file_processors.py} +5 -1
  174. llama_stack/providers/registry/files.py +11 -2
  175. llama_stack/providers/registry/inference.py +22 -3
  176. llama_stack/providers/registry/post_training.py +1 -1
  177. llama_stack/providers/registry/safety.py +1 -1
  178. llama_stack/providers/registry/scoring.py +1 -1
  179. llama_stack/providers/registry/tool_runtime.py +2 -2
  180. llama_stack/providers/registry/vector_io.py +7 -7
  181. llama_stack/providers/remote/datasetio/huggingface/huggingface.py +2 -5
  182. llama_stack/providers/remote/datasetio/nvidia/datasetio.py +1 -4
  183. llama_stack/providers/remote/eval/nvidia/eval.py +15 -9
  184. llama_stack/providers/remote/files/openai/__init__.py +19 -0
  185. llama_stack/providers/remote/files/openai/config.py +28 -0
  186. llama_stack/providers/remote/files/openai/files.py +253 -0
  187. llama_stack/providers/remote/files/s3/files.py +52 -30
  188. llama_stack/providers/remote/inference/anthropic/anthropic.py +2 -1
  189. llama_stack/providers/remote/inference/anthropic/config.py +1 -1
  190. llama_stack/providers/remote/inference/azure/azure.py +1 -3
  191. llama_stack/providers/remote/inference/azure/config.py +8 -7
  192. llama_stack/providers/remote/inference/bedrock/__init__.py +1 -1
  193. llama_stack/providers/remote/inference/bedrock/bedrock.py +82 -105
  194. llama_stack/providers/remote/inference/bedrock/config.py +24 -3
  195. llama_stack/providers/remote/inference/cerebras/cerebras.py +5 -5
  196. llama_stack/providers/remote/inference/cerebras/config.py +12 -5
  197. llama_stack/providers/remote/inference/databricks/config.py +13 -6
  198. llama_stack/providers/remote/inference/databricks/databricks.py +16 -6
  199. llama_stack/providers/remote/inference/fireworks/config.py +5 -5
  200. llama_stack/providers/remote/inference/fireworks/fireworks.py +1 -1
  201. llama_stack/providers/remote/inference/gemini/config.py +1 -1
  202. llama_stack/providers/remote/inference/gemini/gemini.py +13 -14
  203. llama_stack/providers/remote/inference/groq/config.py +5 -5
  204. llama_stack/providers/remote/inference/groq/groq.py +1 -1
  205. llama_stack/providers/remote/inference/llama_openai_compat/config.py +5 -5
  206. llama_stack/providers/remote/inference/llama_openai_compat/llama.py +8 -6
  207. llama_stack/providers/remote/inference/nvidia/__init__.py +1 -1
  208. llama_stack/providers/remote/inference/nvidia/config.py +21 -11
  209. llama_stack/providers/remote/inference/nvidia/nvidia.py +115 -3
  210. llama_stack/providers/remote/inference/nvidia/utils.py +1 -1
  211. llama_stack/providers/remote/inference/oci/__init__.py +17 -0
  212. llama_stack/providers/remote/inference/oci/auth.py +79 -0
  213. llama_stack/providers/remote/inference/oci/config.py +75 -0
  214. llama_stack/providers/remote/inference/oci/oci.py +162 -0
  215. llama_stack/providers/remote/inference/ollama/config.py +7 -5
  216. llama_stack/providers/remote/inference/ollama/ollama.py +17 -8
  217. llama_stack/providers/remote/inference/openai/config.py +4 -4
  218. llama_stack/providers/remote/inference/openai/openai.py +1 -1
  219. llama_stack/providers/remote/inference/passthrough/__init__.py +2 -2
  220. llama_stack/providers/remote/inference/passthrough/config.py +5 -10
  221. llama_stack/providers/remote/inference/passthrough/passthrough.py +97 -75
  222. llama_stack/providers/remote/inference/runpod/config.py +12 -5
  223. llama_stack/providers/remote/inference/runpod/runpod.py +2 -20
  224. llama_stack/providers/remote/inference/sambanova/config.py +5 -5
  225. llama_stack/providers/remote/inference/sambanova/sambanova.py +1 -1
  226. llama_stack/providers/remote/inference/tgi/config.py +7 -6
  227. llama_stack/providers/remote/inference/tgi/tgi.py +19 -11
  228. llama_stack/providers/remote/inference/together/config.py +5 -5
  229. llama_stack/providers/remote/inference/together/together.py +15 -12
  230. llama_stack/providers/remote/inference/vertexai/config.py +1 -1
  231. llama_stack/providers/remote/inference/vllm/config.py +5 -5
  232. llama_stack/providers/remote/inference/vllm/vllm.py +13 -14
  233. llama_stack/providers/remote/inference/watsonx/config.py +4 -4
  234. llama_stack/providers/remote/inference/watsonx/watsonx.py +21 -94
  235. llama_stack/providers/remote/post_training/nvidia/post_training.py +4 -4
  236. llama_stack/providers/remote/post_training/nvidia/utils.py +1 -1
  237. llama_stack/providers/remote/safety/bedrock/bedrock.py +6 -6
  238. llama_stack/providers/remote/safety/bedrock/config.py +1 -1
  239. llama_stack/providers/remote/safety/nvidia/config.py +1 -1
  240. llama_stack/providers/remote/safety/nvidia/nvidia.py +11 -5
  241. llama_stack/providers/remote/safety/sambanova/config.py +1 -1
  242. llama_stack/providers/remote/safety/sambanova/sambanova.py +6 -6
  243. llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +11 -6
  244. llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +12 -7
  245. llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +8 -2
  246. llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +57 -15
  247. llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +11 -6
  248. llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +11 -6
  249. llama_stack/providers/remote/vector_io/chroma/__init__.py +1 -1
  250. llama_stack/providers/remote/vector_io/chroma/chroma.py +125 -20
  251. llama_stack/providers/remote/vector_io/chroma/config.py +1 -1
  252. llama_stack/providers/remote/vector_io/milvus/__init__.py +1 -1
  253. llama_stack/providers/remote/vector_io/milvus/config.py +1 -1
  254. llama_stack/providers/remote/vector_io/milvus/milvus.py +27 -21
  255. llama_stack/providers/remote/vector_io/pgvector/__init__.py +1 -1
  256. llama_stack/providers/remote/vector_io/pgvector/config.py +1 -1
  257. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +26 -18
  258. llama_stack/providers/remote/vector_io/qdrant/__init__.py +1 -1
  259. llama_stack/providers/remote/vector_io/qdrant/config.py +1 -1
  260. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +141 -24
  261. llama_stack/providers/remote/vector_io/weaviate/__init__.py +1 -1
  262. llama_stack/providers/remote/vector_io/weaviate/config.py +1 -1
  263. llama_stack/providers/remote/vector_io/weaviate/weaviate.py +26 -21
  264. llama_stack/providers/utils/common/data_schema_validator.py +1 -5
  265. llama_stack/providers/utils/files/form_data.py +1 -1
  266. llama_stack/providers/utils/inference/embedding_mixin.py +1 -1
  267. llama_stack/providers/utils/inference/inference_store.py +7 -8
  268. llama_stack/providers/utils/inference/litellm_openai_mixin.py +79 -79
  269. llama_stack/providers/utils/inference/model_registry.py +1 -3
  270. llama_stack/providers/utils/inference/openai_compat.py +44 -1171
  271. llama_stack/providers/utils/inference/openai_mixin.py +68 -42
  272. llama_stack/providers/utils/inference/prompt_adapter.py +50 -265
  273. llama_stack/providers/utils/inference/stream_utils.py +23 -0
  274. llama_stack/providers/utils/memory/__init__.py +2 -0
  275. llama_stack/providers/utils/memory/file_utils.py +1 -1
  276. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +181 -84
  277. llama_stack/providers/utils/memory/vector_store.py +39 -38
  278. llama_stack/providers/utils/pagination.py +1 -1
  279. llama_stack/providers/utils/responses/responses_store.py +15 -25
  280. llama_stack/providers/utils/scoring/aggregation_utils.py +1 -2
  281. llama_stack/providers/utils/scoring/base_scoring_fn.py +1 -2
  282. llama_stack/providers/utils/tools/mcp.py +93 -11
  283. llama_stack/telemetry/constants.py +27 -0
  284. llama_stack/telemetry/helpers.py +43 -0
  285. llama_stack/testing/api_recorder.py +25 -16
  286. {llama_stack-0.3.5.dist-info → llama_stack-0.4.0.dist-info}/METADATA +56 -54
  287. llama_stack-0.4.0.dist-info/RECORD +588 -0
  288. llama_stack-0.4.0.dist-info/top_level.txt +2 -0
  289. llama_stack_api/__init__.py +945 -0
  290. llama_stack_api/admin/__init__.py +45 -0
  291. llama_stack_api/admin/api.py +72 -0
  292. llama_stack_api/admin/fastapi_routes.py +117 -0
  293. llama_stack_api/admin/models.py +113 -0
  294. llama_stack_api/agents.py +173 -0
  295. llama_stack_api/batches/__init__.py +40 -0
  296. llama_stack_api/batches/api.py +53 -0
  297. llama_stack_api/batches/fastapi_routes.py +113 -0
  298. llama_stack_api/batches/models.py +78 -0
  299. llama_stack_api/benchmarks/__init__.py +43 -0
  300. llama_stack_api/benchmarks/api.py +39 -0
  301. llama_stack_api/benchmarks/fastapi_routes.py +109 -0
  302. llama_stack_api/benchmarks/models.py +109 -0
  303. {llama_stack/apis → llama_stack_api}/common/content_types.py +1 -43
  304. {llama_stack/apis → llama_stack_api}/common/errors.py +0 -8
  305. {llama_stack/apis → llama_stack_api}/common/job_types.py +1 -1
  306. llama_stack_api/common/responses.py +77 -0
  307. {llama_stack/apis → llama_stack_api}/common/training_types.py +1 -1
  308. {llama_stack/apis → llama_stack_api}/common/type_system.py +2 -14
  309. llama_stack_api/connectors.py +146 -0
  310. {llama_stack/apis/conversations → llama_stack_api}/conversations.py +23 -39
  311. {llama_stack/apis/datasetio → llama_stack_api}/datasetio.py +4 -8
  312. llama_stack_api/datasets/__init__.py +61 -0
  313. llama_stack_api/datasets/api.py +35 -0
  314. llama_stack_api/datasets/fastapi_routes.py +104 -0
  315. llama_stack_api/datasets/models.py +152 -0
  316. {llama_stack/providers → llama_stack_api}/datatypes.py +166 -10
  317. {llama_stack/apis/eval → llama_stack_api}/eval.py +8 -40
  318. llama_stack_api/file_processors/__init__.py +27 -0
  319. llama_stack_api/file_processors/api.py +64 -0
  320. llama_stack_api/file_processors/fastapi_routes.py +78 -0
  321. llama_stack_api/file_processors/models.py +42 -0
  322. llama_stack_api/files/__init__.py +35 -0
  323. llama_stack_api/files/api.py +51 -0
  324. llama_stack_api/files/fastapi_routes.py +124 -0
  325. llama_stack_api/files/models.py +107 -0
  326. {llama_stack/apis/inference → llama_stack_api}/inference.py +90 -194
  327. llama_stack_api/inspect_api/__init__.py +37 -0
  328. llama_stack_api/inspect_api/api.py +25 -0
  329. llama_stack_api/inspect_api/fastapi_routes.py +76 -0
  330. llama_stack_api/inspect_api/models.py +28 -0
  331. {llama_stack/apis/agents → llama_stack_api/internal}/__init__.py +3 -1
  332. llama_stack/providers/utils/kvstore/api.py → llama_stack_api/internal/kvstore.py +5 -0
  333. llama_stack_api/internal/sqlstore.py +79 -0
  334. {llama_stack/apis/models → llama_stack_api}/models.py +11 -9
  335. {llama_stack/apis/agents → llama_stack_api}/openai_responses.py +184 -27
  336. {llama_stack/apis/post_training → llama_stack_api}/post_training.py +7 -11
  337. {llama_stack/apis/prompts → llama_stack_api}/prompts.py +3 -4
  338. llama_stack_api/providers/__init__.py +33 -0
  339. llama_stack_api/providers/api.py +16 -0
  340. llama_stack_api/providers/fastapi_routes.py +57 -0
  341. llama_stack_api/providers/models.py +24 -0
  342. {llama_stack/apis/tools → llama_stack_api}/rag_tool.py +2 -52
  343. {llama_stack/apis → llama_stack_api}/resource.py +1 -1
  344. llama_stack_api/router_utils.py +160 -0
  345. {llama_stack/apis/safety → llama_stack_api}/safety.py +6 -9
  346. {llama_stack → llama_stack_api}/schema_utils.py +94 -4
  347. {llama_stack/apis/scoring → llama_stack_api}/scoring.py +3 -3
  348. {llama_stack/apis/scoring_functions → llama_stack_api}/scoring_functions.py +9 -6
  349. {llama_stack/apis/shields → llama_stack_api}/shields.py +6 -7
  350. {llama_stack/apis/tools → llama_stack_api}/tools.py +26 -21
  351. {llama_stack/apis/vector_io → llama_stack_api}/vector_io.py +133 -152
  352. {llama_stack/apis/vector_stores → llama_stack_api}/vector_stores.py +1 -1
  353. llama_stack/apis/agents/agents.py +0 -894
  354. llama_stack/apis/batches/__init__.py +0 -9
  355. llama_stack/apis/batches/batches.py +0 -100
  356. llama_stack/apis/benchmarks/__init__.py +0 -7
  357. llama_stack/apis/benchmarks/benchmarks.py +0 -108
  358. llama_stack/apis/common/responses.py +0 -36
  359. llama_stack/apis/conversations/__init__.py +0 -31
  360. llama_stack/apis/datasets/datasets.py +0 -251
  361. llama_stack/apis/datatypes.py +0 -160
  362. llama_stack/apis/eval/__init__.py +0 -7
  363. llama_stack/apis/files/__init__.py +0 -7
  364. llama_stack/apis/files/files.py +0 -199
  365. llama_stack/apis/inference/__init__.py +0 -7
  366. llama_stack/apis/inference/event_logger.py +0 -43
  367. llama_stack/apis/inspect/__init__.py +0 -7
  368. llama_stack/apis/inspect/inspect.py +0 -94
  369. llama_stack/apis/models/__init__.py +0 -7
  370. llama_stack/apis/post_training/__init__.py +0 -7
  371. llama_stack/apis/prompts/__init__.py +0 -9
  372. llama_stack/apis/providers/__init__.py +0 -7
  373. llama_stack/apis/providers/providers.py +0 -69
  374. llama_stack/apis/safety/__init__.py +0 -7
  375. llama_stack/apis/scoring/__init__.py +0 -7
  376. llama_stack/apis/scoring_functions/__init__.py +0 -7
  377. llama_stack/apis/shields/__init__.py +0 -7
  378. llama_stack/apis/synthetic_data_generation/__init__.py +0 -7
  379. llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +0 -77
  380. llama_stack/apis/telemetry/__init__.py +0 -7
  381. llama_stack/apis/telemetry/telemetry.py +0 -423
  382. llama_stack/apis/tools/__init__.py +0 -8
  383. llama_stack/apis/vector_io/__init__.py +0 -7
  384. llama_stack/apis/vector_stores/__init__.py +0 -7
  385. llama_stack/core/server/tracing.py +0 -80
  386. llama_stack/core/ui/app.py +0 -55
  387. llama_stack/core/ui/modules/__init__.py +0 -5
  388. llama_stack/core/ui/modules/api.py +0 -32
  389. llama_stack/core/ui/modules/utils.py +0 -42
  390. llama_stack/core/ui/page/__init__.py +0 -5
  391. llama_stack/core/ui/page/distribution/__init__.py +0 -5
  392. llama_stack/core/ui/page/distribution/datasets.py +0 -18
  393. llama_stack/core/ui/page/distribution/eval_tasks.py +0 -20
  394. llama_stack/core/ui/page/distribution/models.py +0 -18
  395. llama_stack/core/ui/page/distribution/providers.py +0 -27
  396. llama_stack/core/ui/page/distribution/resources.py +0 -48
  397. llama_stack/core/ui/page/distribution/scoring_functions.py +0 -18
  398. llama_stack/core/ui/page/distribution/shields.py +0 -19
  399. llama_stack/core/ui/page/evaluations/__init__.py +0 -5
  400. llama_stack/core/ui/page/evaluations/app_eval.py +0 -143
  401. llama_stack/core/ui/page/evaluations/native_eval.py +0 -253
  402. llama_stack/core/ui/page/playground/__init__.py +0 -5
  403. llama_stack/core/ui/page/playground/chat.py +0 -130
  404. llama_stack/core/ui/page/playground/tools.py +0 -352
  405. llama_stack/distributions/dell/build.yaml +0 -33
  406. llama_stack/distributions/meta-reference-gpu/build.yaml +0 -32
  407. llama_stack/distributions/nvidia/build.yaml +0 -29
  408. llama_stack/distributions/open-benchmark/build.yaml +0 -36
  409. llama_stack/distributions/postgres-demo/__init__.py +0 -7
  410. llama_stack/distributions/postgres-demo/build.yaml +0 -23
  411. llama_stack/distributions/postgres-demo/postgres_demo.py +0 -125
  412. llama_stack/distributions/starter/build.yaml +0 -61
  413. llama_stack/distributions/starter-gpu/build.yaml +0 -61
  414. llama_stack/distributions/watsonx/build.yaml +0 -33
  415. llama_stack/providers/inline/agents/meta_reference/agent_instance.py +0 -1024
  416. llama_stack/providers/inline/agents/meta_reference/persistence.py +0 -228
  417. llama_stack/providers/inline/telemetry/__init__.py +0 -5
  418. llama_stack/providers/inline/telemetry/meta_reference/__init__.py +0 -21
  419. llama_stack/providers/inline/telemetry/meta_reference/config.py +0 -47
  420. llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +0 -252
  421. llama_stack/providers/remote/inference/bedrock/models.py +0 -29
  422. llama_stack/providers/utils/kvstore/sqlite/config.py +0 -20
  423. llama_stack/providers/utils/sqlstore/__init__.py +0 -5
  424. llama_stack/providers/utils/sqlstore/api.py +0 -128
  425. llama_stack/providers/utils/telemetry/__init__.py +0 -5
  426. llama_stack/providers/utils/telemetry/trace_protocol.py +0 -142
  427. llama_stack/providers/utils/telemetry/tracing.py +0 -384
  428. llama_stack/strong_typing/__init__.py +0 -19
  429. llama_stack/strong_typing/auxiliary.py +0 -228
  430. llama_stack/strong_typing/classdef.py +0 -440
  431. llama_stack/strong_typing/core.py +0 -46
  432. llama_stack/strong_typing/deserializer.py +0 -877
  433. llama_stack/strong_typing/docstring.py +0 -409
  434. llama_stack/strong_typing/exception.py +0 -23
  435. llama_stack/strong_typing/inspection.py +0 -1085
  436. llama_stack/strong_typing/mapping.py +0 -40
  437. llama_stack/strong_typing/name.py +0 -182
  438. llama_stack/strong_typing/schema.py +0 -792
  439. llama_stack/strong_typing/serialization.py +0 -97
  440. llama_stack/strong_typing/serializer.py +0 -500
  441. llama_stack/strong_typing/slots.py +0 -27
  442. llama_stack/strong_typing/topological.py +0 -89
  443. llama_stack/ui/node_modules/flatted/python/flatted.py +0 -149
  444. llama_stack-0.3.5.dist-info/RECORD +0 -625
  445. llama_stack-0.3.5.dist-info/top_level.txt +0 -1
  446. /llama_stack/{providers/utils → core/storage}/kvstore/config.py +0 -0
  447. /llama_stack/{providers/utils → core/storage}/kvstore/mongodb/__init__.py +0 -0
  448. /llama_stack/{providers/utils → core/storage}/kvstore/postgres/__init__.py +0 -0
  449. /llama_stack/{providers/utils → core/storage}/kvstore/redis/__init__.py +0 -0
  450. /llama_stack/{providers/utils → core/storage}/kvstore/sqlite/__init__.py +0 -0
  451. /llama_stack/{apis → providers/inline/file_processor}/__init__.py +0 -0
  452. /llama_stack/{apis/common → telemetry}/__init__.py +0 -0
  453. {llama_stack-0.3.5.dist-info → llama_stack-0.4.0.dist-info}/WHEEL +0 -0
  454. {llama_stack-0.3.5.dist-info → llama_stack-0.4.0.dist-info}/entry_points.txt +0 -0
  455. {llama_stack-0.3.5.dist-info → llama_stack-0.4.0.dist-info}/licenses/LICENSE +0 -0
  456. {llama_stack/core/ui → llama_stack_api/common}/__init__.py +0 -0
  457. {llama_stack/strong_typing → llama_stack_api}/py.typed +0 -0
  458. {llama_stack/apis → llama_stack_api}/version.py +0 -0
llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
@@ -33,10 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import GenerationResult
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    ChatCompletionRequestWithRawContent,
-    CompletionRequestWithRawContent,
-)
 
 log = get_logger(name=__name__, category="inference")
 
@@ -69,10 +65,7 @@ class CancelSentinel(BaseModel):
 
 class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-    task: tuple[
-        str,
-        list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
-    ]
+    task: tuple[str, list]
 
 
 class TaskResponse(BaseModel):
@@ -328,10 +321,7 @@ class ModelParallelProcessGroup:
 
     def run_inference(
         self,
-        req: tuple[
-            str,
-            list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
-        ],
+        req: tuple[str, list],
     ) -> Generator:
         assert not self.running, "inference already running"
 
llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -6,24 +6,20 @@
 
 from collections.abc import AsyncIterator
 
-from llama_stack.apis.inference import (
-    InferenceProvider,
-    OpenAIChatCompletionRequestWithExtraBody,
-    OpenAICompletionRequestWithExtraBody,
-)
-from llama_stack.apis.inference.inference import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-)
-from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
-from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
+from llama_stack_api import (
+    InferenceProvider,
+    Model,
+    ModelsProtocolPrivate,
+    ModelType,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletion,
+    OpenAICompletionRequestWithExtraBody,
 )
 
 from .config import SentenceTransformersInferenceConfig
@@ -32,7 +28,6 @@ log = get_logger(name=__name__, category="inference")
 
 
 class SentenceTransformersInferenceImpl(
-    OpenAIChatCompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
     InferenceProvider,
     ModelsProtocolPrivate,
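The hunk above is representative of a change repeated across the whole tree in 0.4.0: provider modules stop importing from scattered `llama_stack.apis.*` and `llama_stack.providers.*` submodules and pull API types from the new top-level `llama_stack_api` package instead. A minimal sketch of a consuming module after the move; the names are taken from this diff, but the module below is invented for illustration:

    # Hypothetical consumer module showing the 0.3.x -> 0.4.0 import migration.
    # Before, types were scattered across several internal submodules:
    #   from llama_stack.apis.inference import InferenceProvider
    #   from llama_stack.apis.models import ModelType
    #   from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
    # After, everything comes from one flat, public namespace:
    from llama_stack_api import InferenceProvider, Model, ModelsProtocolPrivate, ModelType

    print(InferenceProvider, Model, ModelsProtocolPrivate, ModelType)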
llama_stack/providers/inline/post_training/common/validator.py
@@ -12,14 +12,10 @@
 
 from typing import Any
 
-from llama_stack.apis.common.type_system import (
-    ChatCompletionInputType,
-    DialogType,
-    StringType,
-)
 from llama_stack.providers.utils.common.data_schema_validator import (
     ColumnName,
 )
+from llama_stack_api import ChatCompletionInputType, DialogType, StringType
 
 EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
     "instruct": [
llama_stack/providers/inline/post_training/huggingface/post_training.py
@@ -6,11 +6,16 @@
 from enum import Enum
 from typing import Any
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
+from llama_stack.providers.inline.post_training.huggingface.config import (
+    HuggingFacePostTrainingConfig,
+)
+from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
+from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
+from llama_stack_api import (
     AlgorithmConfig,
     Checkpoint,
+    DatasetIO,
+    Datasets,
     DPOAlignmentConfig,
     JobStatus,
     ListPostTrainingJobsResponse,
@@ -19,11 +24,6 @@ from llama_stack.apis.post_training import (
     PostTrainingJobStatusResponse,
     TrainingConfig,
 )
-from llama_stack.providers.inline.post_training.huggingface.config import (
-    HuggingFacePostTrainingConfig,
-)
-from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
-from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 
 
 class TrainingArtifactType(Enum):
llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
@@ -14,24 +14,24 @@ import torch
 from datasets import Dataset
 from peft import LoraConfig
 from transformers import (
-    AutoModelForCausalLM,
     AutoTokenizer,
 )
 from trl import SFTConfig, SFTTrainer
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
+from llama_stack.log import get_logger
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
+from llama_stack_api import (
     Checkpoint,
     DataConfig,
+    DatasetIO,
+    Datasets,
     LoraFinetuningConfig,
     TrainingConfig,
 )
-from llama_stack.log import get_logger
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
+    HFAutoModel,
     calculate_training_steps,
     create_checkpoints,
     get_memory_stats,
@@ -338,7 +338,7 @@ class HFFinetuningSingleDevice:
 
     def save_model(
         self,
-        model_obj: AutoModelForCausalLM,
+        model_obj: HFAutoModel,
         trainer: SFTTrainer,
         peft_config: LoraConfig | None,
         output_dir_path: Path,
@@ -350,14 +350,22 @@ class HFFinetuningSingleDevice:
             peft_config: Optional LoRA configuration
             output_dir_path: Path to save the model
         """
+        from typing import cast
+
        logger.info("Saving final model")
        model_obj.config.use_cache = True
 
        if peft_config:
            logger.info("Merging LoRA weights with base model")
-            model_obj = trainer.model.merge_and_unload()
+            # TRL's merge_and_unload returns a HuggingFace model
+            # Both cast() and type: ignore are needed here:
+            # - cast() tells mypy the return type is HFAutoModel for downstream code
+            # - type: ignore suppresses errors on the merge_and_unload() call itself,
+            #   which mypy can't type-check due to TRL library's incomplete type stubs
+            model_obj = cast(HFAutoModel, trainer.model.merge_and_unload())  # type: ignore[union-attr,operator]
        else:
-            model_obj = trainer.model
+            # trainer.model is the trained HuggingFace model
+            model_obj = cast(HFAutoModel, trainer.model)
 
        save_path = output_dir_path / "merged_model"
        logger.info(f"Saving model to {save_path}")
@@ -411,7 +419,7 @@ class HFFinetuningSingleDevice:
         # Initialize trainer
         logger.info("Initializing SFTTrainer")
         trainer = SFTTrainer(
-            model=model_obj,
+            model=model_obj,  # type: ignore[arg-type]
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             peft_config=peft_config,
llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py
@@ -16,15 +16,15 @@ from transformers import (
 )
 from trl import DPOConfig, DPOTrainer
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
+from llama_stack.log import get_logger
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
+from llama_stack_api import (
     Checkpoint,
+    DatasetIO,
+    Datasets,
     DPOAlignmentConfig,
     TrainingConfig,
 )
-from llama_stack.log import get_logger
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
@@ -309,7 +309,7 @@ class HFDPOAlignmentSingleDevice:
             save_total_limit=provider_config.save_total_limit,
             # DPO specific parameters
             beta=dpo_config.beta,
-            loss_type=provider_config.dpo_loss_type,
+            loss_type=provider_config.dpo_loss_type,  # type: ignore[arg-type]
         )
 
     def save_model(
@@ -381,13 +381,16 @@ class HFDPOAlignmentSingleDevice:
 
         # Initialize DPO trainer
         logger.info("Initializing DPOTrainer")
+        # TRL library has incomplete type stubs - use Any to bypass
+        from typing import Any, cast
+
         trainer = DPOTrainer(
-            model=model_obj,
-            ref_model=ref_model,
+            model=cast(Any, model_obj),  # HFAutoModel satisfies PreTrainedModel protocol
+            ref_model=cast(Any, ref_model),
             args=training_args,
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
-            processing_class=tokenizer,
+            processing_class=cast(Any, tokenizer),  # AutoTokenizer satisfies interface
         )
 
         try:
llama_stack/providers/inline/post_training/huggingface/utils.py
@@ -9,15 +9,33 @@ import signal
 import sys
 from datetime import UTC, datetime
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any, Protocol
 
 import psutil
 import torch
 from datasets import Dataset
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.post_training import Checkpoint, TrainingConfig
+from llama_stack_api import Checkpoint, DatasetIO, TrainingConfig
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+
+class HFAutoModel(Protocol):
+    """Protocol describing HuggingFace AutoModel interface.
+
+    This protocol defines the common interface for HuggingFace AutoModelForCausalLM
+    and similar models, providing type safety without requiring type stubs.
+    """
+
+    config: PretrainedConfig
+    device: torch.device
+
+    def to(self, device: torch.device) -> "HFAutoModel": ...
+    def save_pretrained(self, save_directory: str | Path) -> None: ...
+
+
 from llama_stack.log import get_logger
 
 from .config import HuggingFacePostTrainingConfig
@@ -132,7 +150,7 @@ def load_model(
     model: str,
     device: torch.device,
     provider_config: HuggingFacePostTrainingConfig,
-) -> AutoModelForCausalLM:
+) -> HFAutoModel:
     """Load and initialize the model for training.
     Args:
         model: The model identifier to load
@@ -143,6 +161,8 @@ def load_model(
     Raises:
         RuntimeError: If model loading fails
     """
+    from typing import cast
+
    logger.info("Loading the base model")
    try:
        model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
@@ -154,9 +174,10 @@ def load_model(
             **provider_config.model_specific_config,
         )
         # Always move model to specified device
-        model_obj = model_obj.to(device)
+        model_obj = model_obj.to(device)  # type: ignore[arg-type]
         logger.info(f"Model loaded and moved to device: {model_obj.device}")
-        return model_obj
+        # Cast to HFAutoModel protocol - transformers models satisfy this interface
+        return cast(HFAutoModel, model_obj)
     except Exception as e:
         raise RuntimeError(f"Failed to load model: {str(e)}") from e
 
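The `HFAutoModel` Protocol introduced above (and the `cast(Any, ...)` calls in the DPO recipe) are the standard structural-typing answer to a library without usable type stubs: declare only the members the code actually touches, then `cast()` the dynamically typed object to that view. A self-contained sketch of the same technique, independent of transformers; the `Greeter` names below are invented for illustration:

    from typing import Protocol, cast


    class Greeter(Protocol):
        """Only the members callers rely on; no inheritance required."""

        name: str

        def greet(self) -> str: ...


    class _DynamicThing:
        # Stand-in for a library object that ships without type stubs.
        def __init__(self) -> None:
            self.name = "llama"

        def greet(self) -> str:
            return f"hello, {self.name}"


    def load_thing() -> Greeter:
        obj = _DynamicThing()
        # cast() is a no-op at runtime; it only tells the type checker that obj
        # satisfies Greeter, mirroring how utils.py casts the loaded
        # transformers model to HFAutoModel.
        return cast(Greeter, obj)


    print(load_thing().greet())  # hello, llama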
llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
@@ -91,7 +91,7 @@ class TorchtuneCheckpointer:
         if checkpoint_format == "meta" or checkpoint_format is None:
             self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
         elif checkpoint_format == "huggingface":
-            # Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
+            # Note: for saving hugging face format checkpoints, we only support saving adapter weights now
             self._save_hf_format_checkpoint(model_file_path, state_dict)
         else:
             raise ValueError(f"Unsupported checkpoint format: {format}")
llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -21,9 +21,9 @@ from torchtune.models.llama3_1 import lora_llama3_1_8b
 from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform
 
-from llama_stack.apis.post_training import DatasetFormat
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.models.llama.sku_types import Model
+from llama_stack_api import DatasetFormat
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py
@@ -25,7 +25,7 @@ def llama_stack_instruct_to_torchtune_instruct(
     )
     input_messages = json.loads(sample[ColumnName.chat_completion_input.value])
 
-    assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
+    assert len(input_messages) == 1, "llama stack instruct dataset format only supports 1 user message"
     input_message = input_messages[0]
 
     assert "content" in input_message, "content not found in input message"
llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -6,11 +6,16 @@
 from enum import Enum
 from typing import Any
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
+from llama_stack.providers.inline.post_training.torchtune.config import (
+    TorchtunePostTrainingConfig,
+)
+from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
+from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
+from llama_stack_api import (
     AlgorithmConfig,
     Checkpoint,
+    DatasetIO,
+    Datasets,
     DPOAlignmentConfig,
     JobStatus,
     ListPostTrainingJobsResponse,
@@ -20,11 +25,6 @@ from llama_stack.apis.post_training import (
     PostTrainingJobStatusResponse,
     TrainingConfig,
 )
-from llama_stack.providers.inline.post_training.torchtune.config import (
-    TorchtunePostTrainingConfig,
-)
-from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
-from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 
 
 class TrainingArtifactType(Enum):
llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -32,17 +32,6 @@ from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
 from torchtune.training.metric_logging import DiskLogger
 from tqdm import tqdm
 
-from llama_stack.apis.common.training_types import PostTrainingMetric
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
-    Checkpoint,
-    DataConfig,
-    LoraFinetuningConfig,
-    OptimizerConfig,
-    QATFinetuningConfig,
-    TrainingConfig,
-)
 from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.core.utils.model_utils import model_local_dir
 from llama_stack.log import get_logger
@@ -56,6 +45,17 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
+from llama_stack_api import (
+    Checkpoint,
+    DataConfig,
+    DatasetIO,
+    Datasets,
+    LoraFinetuningConfig,
+    OptimizerConfig,
+    PostTrainingMetric,
+    QATFinetuningConfig,
+    TrainingConfig,
+)
 
 log = get_logger(name=__name__, category="post_training")
 
@@ -193,7 +193,7 @@ class LoraFinetuningSingleDevice:
         log.info("Optimizer is initialized.")
 
         self._loss_fn = CEWithChunkedOutputLoss()
-        self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
+        self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)  # type: ignore[operator]
         log.info("Loss is initialized.")
 
         assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
@@ -284,7 +284,7 @@ class LoraFinetuningSingleDevice:
         if self._is_dora:
             for m in model.modules():
                 if hasattr(m, "initialize_dora_magnitude"):
-                    m.initialize_dora_magnitude()
+                    m.initialize_dora_magnitude()  # type: ignore[operator]
         if lora_weights_state_dict:
             lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
         else:
@@ -353,7 +353,7 @@ class LoraFinetuningSingleDevice:
             dataset_type=self._data_format.value,
         )
 
-        sampler = DistributedSampler(
+        sampler: DistributedSampler = DistributedSampler(
             ds,
             num_replicas=1,
             rank=0,
@@ -389,7 +389,7 @@ class LoraFinetuningSingleDevice:
             num_training_steps=num_training_steps,
             last_epoch=last_epoch,
         )
-        return lr_scheduler
+        return lr_scheduler  # type: ignore[no-any-return]
 
     async def save_checkpoint(self, epoch: int) -> str:
         ckpt_dict = {}
@@ -447,7 +447,7 @@ class LoraFinetuningSingleDevice:
         # free logits otherwise it peaks backward memory
         del logits
 
-        return loss
+        return loss  # type: ignore[no-any-return]
 
     async def train(self) -> tuple[dict[str, Any], list[Checkpoint]]:
         """
llama_stack/providers/inline/safety/code_scanner/code_scanner.py
@@ -10,19 +10,20 @@ from typing import TYPE_CHECKING, Any
 if TYPE_CHECKING:
     from codeshield.cs import CodeShieldScanResult
 
-from llama_stack.apis.inference import OpenAIMessageParam
-from llama_stack.apis.safety import (
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
+from llama_stack_api import (
+    ModerationObject,
+    ModerationObjectResults,
+    OpenAIMessageParam,
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    Shield,
     ViolationLevel,
 )
-from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
-from llama_stack.apis.shields import Shield
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    interleaved_content_as_str,
-)
 
 from .config import CodeScannerConfig
 
@@ -101,7 +102,10 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             metadata=metadata,
         )
 
-    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
+        if model is None:
+            raise ValueError("Code scanner moderation requires a model identifier.")
+
         inputs = input if isinstance(input, list) else [input]
         results = []
 
llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -9,29 +9,29 @@ import uuid
 from string import Template
 from typing import Any
 
-from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
-from llama_stack.apis.inference import (
+from llama_stack.core.datatypes import Api
+from llama_stack.log import get_logger
+from llama_stack.models.llama.datatypes import Role
+from llama_stack.models.llama.sku_types import CoreModelId
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
+from llama_stack_api import (
+    ImageContentItem,
     Inference,
+    ModerationObject,
+    ModerationObjectResults,
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAIMessageParam,
     OpenAIUserMessageParam,
-)
-from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    Shield,
+    ShieldsProtocolPrivate,
+    TextContentItem,
     ViolationLevel,
 )
-from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
-from llama_stack.apis.shields import Shield
-from llama_stack.core.datatypes import Api
-from llama_stack.log import get_logger
-from llama_stack.models.llama.datatypes import Role
-from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    interleaved_content_as_str,
-)
 
 from .config import LlamaGuardConfig
 
@@ -200,7 +200,10 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
 
         return await impl.run(messages)
 
-    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
+        if model is None:
+            raise ValueError("Llama Guard moderation requires a model identifier.")
+
         if isinstance(input, list):
             messages = input.copy()
         else:
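All three safety providers touched here (code scanner, Llama Guard, Prompt Guard) relax `run_moderation` to accept `model: str | None = None` and then fail fast when a model is actually required. A minimal standalone sketch of that guard pattern; the `ModerationResult` class below is an invented stand-in for the real `ModerationObject`:

    from dataclasses import dataclass


    @dataclass
    class ModerationResult:
        # Hypothetical stand-in for llama_stack_api's ModerationObject.
        model: str
        flagged: list[bool]


    async def run_moderation(input: str | list[str], model: str | None = None) -> ModerationResult:
        # Optional at the API boundary, so callers relying on a server-side
        # default may omit the model; a provider that cannot supply a default
        # still rejects None up front with a clear error instead of failing later.
        if model is None:
            raise ValueError("this provider requires a model identifier")
        inputs = input if isinstance(input, list) else [input]
        return ModerationResult(model=model, flagged=[False for _ in inputs])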
llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -9,20 +9,20 @@ from typing import Any
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
-from llama_stack.apis.inference import OpenAIMessageParam
-from llama_stack.apis.safety import (
+from llama_stack.core.utils.model_utils import model_local_dir
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
+from llama_stack_api import (
+    ModerationObject,
+    OpenAIMessageParam,
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    Shield,
+    ShieldsProtocolPrivate,
     ShieldStore,
     ViolationLevel,
 )
-from llama_stack.apis.safety.safety import ModerationObject
-from llama_stack.apis.shields import Shield
-from llama_stack.core.utils.model_utils import model_local_dir
-from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 
 from .config import PromptGuardConfig, PromptGuardType
 
@@ -63,7 +63,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
 
         return await self.shield.run(messages)
 
-    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
         raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
 
 
llama_stack/providers/inline/scoring/basic/scoring.py
@@ -5,21 +5,17 @@
 # the root directory of this source tree.
 from typing import Any
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.scoring import (
+from llama_stack_api import (
+    DatasetIO,
+    Datasets,
     ScoreBatchResponse,
     ScoreResponse,
     Scoring,
+    ScoringFn,
+    ScoringFnParams,
+    ScoringFunctionsProtocolPrivate,
     ScoringResult,
 )
-from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
-from llama_stack.core.datatypes import Api
-from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.utils.common.data_schema_validator import (
-    get_valid_schemas,
-    validate_dataset_schema,
-)
 
 from .config import BasicScoringConfig
 from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
@@ -83,9 +79,6 @@ class BasicScoringImpl(
         scoring_functions: dict[str, ScoringFnParams | None] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
-
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
             limit=-1,
llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py
@@ -8,9 +8,8 @@ import json
 import re
 from typing import Any
 
-from llama_stack.apis.scoring import ScoringResultRow
-from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
+from llama_stack_api import ScoringFnParams, ScoringResultRow
 
 from .fn_defs.docvqa import docvqa
 
llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py
@@ -6,9 +6,8 @@
 
 from typing import Any
 
-from llama_stack.apis.scoring import ScoringResultRow
-from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
+from llama_stack_api import ScoringFnParams, ScoringResultRow
 
 from .fn_defs.equality import equality
 
llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py
@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import (
+from llama_stack_api import (
     AggregationFunctionType,
     BasicScoringFnParams,
+    NumberType,
     ScoringFn,
 )
 