llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (311)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  71. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  72. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  73. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  74. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  75. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  76. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  77. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  78. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  79. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  80. llama_stack/providers/registry/agents.py +1 -0
  81. llama_stack/providers/registry/inference.py +1 -9
  82. llama_stack/providers/registry/vector_io.py +136 -16
  83. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  84. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  85. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  86. llama_stack/providers/remote/files/s3/README.md +266 -0
  87. llama_stack/providers/remote/files/s3/config.py +5 -3
  88. llama_stack/providers/remote/files/s3/files.py +2 -2
  89. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  90. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  91. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  92. llama_stack/providers/remote/inference/together/together.py +4 -0
  93. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  94. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  95. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  96. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  97. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  98. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  99. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  100. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  101. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  102. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  103. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  104. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  105. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  106. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  107. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  108. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  109. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  110. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  111. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  112. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  113. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  114. llama_stack/providers/utils/bedrock/client.py +3 -3
  115. llama_stack/providers/utils/bedrock/config.py +7 -7
  116. llama_stack/providers/utils/inference/__init__.py +0 -25
  117. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  118. llama_stack/providers/utils/inference/http_client.py +239 -0
  119. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  120. llama_stack/providers/utils/inference/model_registry.py +148 -2
  121. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  122. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  123. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  124. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  125. llama_stack/providers/utils/memory/vector_store.py +46 -19
  126. llama_stack/providers/utils/responses/responses_store.py +40 -6
  127. llama_stack/providers/utils/safety.py +114 -0
  128. llama_stack/providers/utils/tools/mcp.py +44 -3
  129. llama_stack/testing/api_recorder.py +9 -3
  130. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  131. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
  132. llama_stack-0.5.0.dist-info/top_level.txt +1 -0
  133. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  134. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  135. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  136. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  137. llama_stack/models/llama/hadamard_utils.py +0 -88
  138. llama_stack/models/llama/llama3/args.py +0 -74
  139. llama_stack/models/llama/llama3/generation.py +0 -378
  140. llama_stack/models/llama/llama3/model.py +0 -304
  141. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  142. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  143. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  144. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  145. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  146. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  147. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  148. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  149. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  150. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  151. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  152. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  153. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  154. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  155. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  156. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  157. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  158. llama_stack/models/llama/llama4/args.py +0 -107
  159. llama_stack/models/llama/llama4/ffn.py +0 -58
  160. llama_stack/models/llama/llama4/moe.py +0 -214
  161. llama_stack/models/llama/llama4/preprocess.py +0 -435
  162. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  163. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  164. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  165. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  166. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  167. llama_stack/models/llama/quantize_impls.py +0 -316
  168. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  169. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  170. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  171. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  172. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  173. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  174. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  175. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  176. llama_stack_api/__init__.py +0 -945
  177. llama_stack_api/admin/__init__.py +0 -45
  178. llama_stack_api/admin/api.py +0 -72
  179. llama_stack_api/admin/fastapi_routes.py +0 -117
  180. llama_stack_api/admin/models.py +0 -113
  181. llama_stack_api/agents.py +0 -173
  182. llama_stack_api/batches/__init__.py +0 -40
  183. llama_stack_api/batches/api.py +0 -53
  184. llama_stack_api/batches/fastapi_routes.py +0 -113
  185. llama_stack_api/batches/models.py +0 -78
  186. llama_stack_api/benchmarks/__init__.py +0 -43
  187. llama_stack_api/benchmarks/api.py +0 -39
  188. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  189. llama_stack_api/benchmarks/models.py +0 -109
  190. llama_stack_api/common/__init__.py +0 -5
  191. llama_stack_api/common/content_types.py +0 -101
  192. llama_stack_api/common/errors.py +0 -95
  193. llama_stack_api/common/job_types.py +0 -38
  194. llama_stack_api/common/responses.py +0 -77
  195. llama_stack_api/common/training_types.py +0 -47
  196. llama_stack_api/common/type_system.py +0 -146
  197. llama_stack_api/connectors.py +0 -146
  198. llama_stack_api/conversations.py +0 -270
  199. llama_stack_api/datasetio.py +0 -55
  200. llama_stack_api/datasets/__init__.py +0 -61
  201. llama_stack_api/datasets/api.py +0 -35
  202. llama_stack_api/datasets/fastapi_routes.py +0 -104
  203. llama_stack_api/datasets/models.py +0 -152
  204. llama_stack_api/datatypes.py +0 -373
  205. llama_stack_api/eval.py +0 -137
  206. llama_stack_api/file_processors/__init__.py +0 -27
  207. llama_stack_api/file_processors/api.py +0 -64
  208. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  209. llama_stack_api/file_processors/models.py +0 -42
  210. llama_stack_api/files/__init__.py +0 -35
  211. llama_stack_api/files/api.py +0 -51
  212. llama_stack_api/files/fastapi_routes.py +0 -124
  213. llama_stack_api/files/models.py +0 -107
  214. llama_stack_api/inference.py +0 -1169
  215. llama_stack_api/inspect_api/__init__.py +0 -37
  216. llama_stack_api/inspect_api/api.py +0 -25
  217. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  218. llama_stack_api/inspect_api/models.py +0 -28
  219. llama_stack_api/internal/kvstore.py +0 -28
  220. llama_stack_api/internal/sqlstore.py +0 -81
  221. llama_stack_api/llama_stack_api/__init__.py +0 -945
  222. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  223. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  224. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  225. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  226. llama_stack_api/llama_stack_api/agents.py +0 -173
  227. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  228. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  229. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  230. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  231. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  232. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  233. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  234. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  235. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  236. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  237. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  238. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  239. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  240. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  241. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  242. llama_stack_api/llama_stack_api/connectors.py +0 -146
  243. llama_stack_api/llama_stack_api/conversations.py +0 -270
  244. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  245. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  246. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  247. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  248. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  249. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  250. llama_stack_api/llama_stack_api/eval.py +0 -137
  251. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  252. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  253. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  254. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  255. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  256. llama_stack_api/llama_stack_api/files/api.py +0 -51
  257. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  258. llama_stack_api/llama_stack_api/files/models.py +0 -107
  259. llama_stack_api/llama_stack_api/inference.py +0 -1169
  260. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  261. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  262. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  263. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  264. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  265. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  266. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  267. llama_stack_api/llama_stack_api/models.py +0 -171
  268. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  269. llama_stack_api/llama_stack_api/post_training.py +0 -370
  270. llama_stack_api/llama_stack_api/prompts.py +0 -203
  271. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  272. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  273. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  274. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  275. llama_stack_api/llama_stack_api/py.typed +0 -0
  276. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  277. llama_stack_api/llama_stack_api/resource.py +0 -37
  278. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  279. llama_stack_api/llama_stack_api/safety.py +0 -132
  280. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  281. llama_stack_api/llama_stack_api/scoring.py +0 -93
  282. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  283. llama_stack_api/llama_stack_api/shields.py +0 -93
  284. llama_stack_api/llama_stack_api/tools.py +0 -226
  285. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  286. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  287. llama_stack_api/llama_stack_api/version.py +0 -9
  288. llama_stack_api/models.py +0 -171
  289. llama_stack_api/openai_responses.py +0 -1468
  290. llama_stack_api/post_training.py +0 -370
  291. llama_stack_api/prompts.py +0 -203
  292. llama_stack_api/providers/__init__.py +0 -33
  293. llama_stack_api/providers/api.py +0 -16
  294. llama_stack_api/providers/fastapi_routes.py +0 -57
  295. llama_stack_api/providers/models.py +0 -24
  296. llama_stack_api/py.typed +0 -0
  297. llama_stack_api/rag_tool.py +0 -168
  298. llama_stack_api/resource.py +0 -37
  299. llama_stack_api/router_utils.py +0 -160
  300. llama_stack_api/safety.py +0 -132
  301. llama_stack_api/schema_utils.py +0 -208
  302. llama_stack_api/scoring.py +0 -93
  303. llama_stack_api/scoring_functions.py +0 -211
  304. llama_stack_api/shields.py +0 -93
  305. llama_stack_api/tools.py +0 -226
  306. llama_stack_api/vector_io.py +0 -941
  307. llama_stack_api/vector_stores.py +0 -53
  308. llama_stack_api/version.py +0 -9
  309. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  310. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  311. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0

llama_stack/providers/inline/post_training/huggingface/post_training.py

@@ -12,17 +12,19 @@ from llama_stack.providers.inline.post_training.huggingface.config import (
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack_api import (
-    AlgorithmConfig,
+    CancelTrainingJobRequest,
     Checkpoint,
     DatasetIO,
     Datasets,
-    DPOAlignmentConfig,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     ListPostTrainingJobsResponse,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-    TrainingConfig,
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )


@@ -69,13 +71,7 @@ class HuggingFacePostTrainingImpl:

     async def supervised_fine_tune(
         self,
-        job_uuid: str,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None = None,
-        algorithm_config: AlgorithmConfig | None = None,
+        request: SupervisedFineTuneRequest,
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
@@ -85,17 +81,17 @@ class HuggingFacePostTrainingImpl:
             on_log_message_cb("Starting HF finetuning")

             recipe = HFFinetuningSingleDevice(
-                job_uuid=job_uuid,
+                job_uuid=request.job_uuid,
                 datasetio_api=self.datasetio_api,
                 datasets_api=self.datasets_api,
             )

             resources_allocated, checkpoints = await recipe.train(
-                model=model,
-                output_dir=checkpoint_dir,
-                job_uuid=job_uuid,
-                lora_config=algorithm_config,
-                config=training_config,
+                model=request.model,
+                output_dir=request.checkpoint_dir,
+                job_uuid=request.job_uuid,
+                lora_config=request.algorithm_config,
+                config=request.training_config,
                 provider_config=self.config,
             )

@@ -108,17 +104,12 @@ class HuggingFacePostTrainingImpl:
             on_status_change_cb(SchedulerJobStatus.completed)
             on_log_message_cb("HF finetuning completed")

-        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)

     async def preference_optimize(
         self,
-        job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
@@ -128,17 +119,17 @@ class HuggingFacePostTrainingImpl:
             on_log_message_cb("Starting HF DPO alignment")

             recipe = HFDPOAlignmentSingleDevice(
-                job_uuid=job_uuid,
+                job_uuid=request.job_uuid,
                 datasetio_api=self.datasetio_api,
                 datasets_api=self.datasets_api,
             )

             resources_allocated, checkpoints = await recipe.train(
-                model=finetuned_model,
-                output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
-                job_uuid=job_uuid,
-                dpo_config=algorithm_config,
-                config=training_config,
+                model=request.finetuned_model,
+                output_dir=f"{self.config.dpo_output_dir}/{request.job_uuid}",
+                job_uuid=request.job_uuid,
+                dpo_config=request.algorithm_config,
+                config=request.training_config,
                 provider_config=self.config,
             )

@@ -153,7 +144,7 @@ class HuggingFacePostTrainingImpl:
             on_status_change_cb(SchedulerJobStatus.completed)
             on_log_message_cb("HF DPO alignment completed")

-        job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)

     @staticmethod
@@ -169,8 +160,10 @@ class HuggingFacePostTrainingImpl:
         data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
         return data[0] if data else None

-    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
-        job = self._scheduler.get_job(job_uuid)
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> PostTrainingJobStatusResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)

         match job.status:
             # TODO: Add support for other statuses to API
@@ -186,7 +179,7 @@ class HuggingFacePostTrainingImpl:
                 raise NotImplementedError()

         return PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             status=status,
             scheduled_at=job.scheduled_at,
             started_at=job.started_at,
@@ -195,12 +188,14 @@ class HuggingFacePostTrainingImpl:
             resources_allocated=self._get_resources_allocated(job),
         )

-    async def cancel_training_job(self, job_uuid: str) -> None:
-        self._scheduler.cancel(job_uuid)
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
+        self._scheduler.cancel(request.job_uuid)

-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
-        job = self._scheduler.get_job(job_uuid)
-        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=request.job_uuid, checkpoints=self._get_checkpoints(job))

     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
         return ListPostTrainingJobsResponse(
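
The hunks above replace the flattened keyword arguments of each post-training entry point with a single typed request object. Below is a minimal caller-side sketch of the migration, assuming SupervisedFineTuneRequest carries exactly the fields the provider reads (job_uuid, model, checkpoint_dir, algorithm_config, training_config, plus the hyperparam_search_config and logger_config fields visible in the torchtune hunks further down); the two config objects are placeholders, since their constructors are not part of this diff.

    # Hypothetical migration sketch, not verbatim API: config construction is elided.
    from llama_stack_api import SupervisedFineTuneRequest

    async def start_finetune(impl, training_config, lora_config) -> str:
        # 0.4.3 passed job_uuid, model, training_config, ... as separate kwargs;
        # 0.5.0 bundles them into one request object.
        request = SupervisedFineTuneRequest(
            job_uuid="job-123",
            model="meta-llama/Llama-3.2-3B-Instruct",
            training_config=training_config,    # a TrainingConfig instance
            hyperparam_search_config={},
            logger_config={},
            checkpoint_dir=None,
            algorithm_config=lora_config,       # e.g. a LoraFinetuningConfig
        )
        job = await impl.supervised_fine_tune(request)
        return job.job_uuid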

llama_stack/providers/inline/post_training/huggingface/utils.py

@@ -16,7 +16,7 @@ import torch
 from datasets import Dataset
 from transformers import AutoConfig, AutoModelForCausalLM

-from llama_stack_api import Checkpoint, DatasetIO, TrainingConfig
+from llama_stack_api import Checkpoint, DatasetIO, IterRowsRequest, TrainingConfig

 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -135,10 +135,7 @@ def setup_torch_device(device_str: str) -> torch.device:
 async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
     """Load dataset from llama stack dataset provider"""
     try:
-        all_rows = await datasetio_api.iterrows(
-            dataset_id=dataset_id,
-            limit=-1,
-        )
+        all_rows = await datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))
         if not isinstance(all_rows.data, list):
             raise RuntimeError("Expected dataset data to be a list")
         return all_rows.data
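
The DatasetIO call gets the same treatment: iterrows now takes an IterRowsRequest. A small sketch of the new call shape, using only what the hunk shows (limit=-1 is the provider's convention for fetching every row):

    from llama_stack_api import IterRowsRequest

    async def load_all_rows(datasetio_api, dataset_id: str) -> list[dict]:
        # One typed request instead of dataset_id/limit keyword arguments.
        response = await datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))
        if not isinstance(response.data, list):
            raise RuntimeError("Expected dataset data to be a list")
        return response.data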

llama_stack/providers/inline/post_training/torchtune/common/utils.py

@@ -22,7 +22,6 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform

 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.models.llama.sku_types import Model
 from llama_stack_api import DatasetFormat

 BuildLoraModelCallable = Callable[..., torch.nn.Module]
@@ -54,18 +53,17 @@ DATA_FORMATS: dict[str, Transform] = {
 }


-def _validate_model_id(model_id: str) -> Model:
+def _validate_model_id(model_id: str) -> str:
     model = resolve_model(model_id)
     if model is None or model.core_model_id.value not in MODEL_CONFIGS:
         raise ValueError(f"Model {model_id} is not supported.")
-    return model
+    return model.core_model_id.value


 async def get_model_definition(
     model_id: str,
 ) -> BuildLoraModelCallable:
-    model = _validate_model_id(model_id)
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    model_config = MODEL_CONFIGS[_validate_model_id(model_id)]
     if not hasattr(model_config, "model_definition"):
         raise ValueError(f"Model {model_id} does not have model definition.")
     return model_config.model_definition
@@ -74,8 +72,7 @@ async def get_model_definition(
 async def get_tokenizer_type(
     model_id: str,
 ) -> BuildTokenizerCallable:
-    model = _validate_model_id(model_id)
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    model_config = MODEL_CONFIGS[_validate_model_id(model_id)]
     if not hasattr(model_config, "tokenizer_type"):
         raise ValueError(f"Model {model_id} does not have tokenizer_type.")
     return model_config.tokenizer_type
@@ -88,8 +85,7 @@ async def get_checkpointer_model_type(
     checkpointer model type is used in checkpointer for some special treatment on some specific model types
     For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
     """
-    model = _validate_model_id(model_id)
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    model_config = MODEL_CONFIGS[_validate_model_id(model_id)]
     if not hasattr(model_config, "checkpoint_type"):
         raise ValueError(f"Model {model_id} does not have checkpoint_type.")
     return model_config.checkpoint_type
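
The net effect of these four hunks: _validate_model_id now returns the MODEL_CONFIGS key itself, so each accessor collapses its two-step lookup into a single expression. A generic sketch of the pattern, with stand-in names rather than the real torchtune registry:

    # Validate-and-return-the-key: callers index the table in one expression.
    CONFIGS: dict[str, str] = {"llama3_2_3b": "lora_llama3_2_3b"}  # stand-in values

    def validate(model_id: str) -> str:
        if model_id not in CONFIGS:
            raise ValueError(f"Model {model_id} is not supported.")
        return model_id

    definition = CONFIGS[validate("llama3_2_3b")]  # was: two statements per accessor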

llama_stack/providers/inline/post_training/torchtune/post_training.py

@@ -12,18 +12,20 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack_api import (
-    AlgorithmConfig,
+    CancelTrainingJobRequest,
     Checkpoint,
     DatasetIO,
     Datasets,
-    DPOAlignmentConfig,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     ListPostTrainingJobsResponse,
     LoraFinetuningConfig,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-    TrainingConfig,
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )


@@ -69,15 +71,9 @@ class TorchtunePostTrainingImpl:

     async def supervised_fine_tune(
         self,
-        job_uuid: str,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None,
-        algorithm_config: AlgorithmConfig | None,
+        request: SupervisedFineTuneRequest,
     ) -> PostTrainingJob:
-        if isinstance(algorithm_config, LoraFinetuningConfig):
+        if isinstance(request.algorithm_config, LoraFinetuningConfig):

            async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
                from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
@@ -88,13 +84,13 @@ class TorchtunePostTrainingImpl:

                recipe = LoraFinetuningSingleDevice(
                    self.config,
-                   job_uuid,
-                   training_config,
-                   hyperparam_search_config,
-                   logger_config,
-                   model,
-                   checkpoint_dir,
-                   algorithm_config,
+                   request.job_uuid,
+                   request.training_config,
+                   request.hyperparam_search_config,
+                   request.logger_config,
+                   request.model,
+                   request.checkpoint_dir,
+                   request.algorithm_config,
                    self.datasetio_api,
                    self.datasets_api,
                )
@@ -112,17 +108,12 @@ class TorchtunePostTrainingImpl:
        else:
            raise NotImplementedError()

-       job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+       job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, request.job_uuid, handler)
        return PostTrainingJob(job_uuid=job_uuid)

    async def preference_optimize(
        self,
-       job_uuid: str,
-       finetuned_model: str,
-       algorithm_config: DPOAlignmentConfig,
-       training_config: TrainingConfig,
-       hyperparam_search_config: dict[str, Any],
-       logger_config: dict[str, Any],
+       request: PreferenceOptimizeRequest,
    ) -> PostTrainingJob:
        raise NotImplementedError()

@@ -144,8 +135,10 @@ class TorchtunePostTrainingImpl:
        data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
        return data[0] if data else None

-   async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
-       job = self._scheduler.get_job(job_uuid)
+   async def get_training_job_status(
+       self, request: GetTrainingJobStatusRequest
+   ) -> PostTrainingJobStatusResponse | None:
+       job = self._scheduler.get_job(request.job_uuid)

        match job.status:
            # TODO: Add support for other statuses to API
@@ -161,7 +154,7 @@ class TorchtunePostTrainingImpl:
                raise NotImplementedError()

        return PostTrainingJobStatusResponse(
-           job_uuid=job_uuid,
+           job_uuid=request.job_uuid,
            status=status,
            scheduled_at=job.scheduled_at,
            started_at=job.started_at,
@@ -170,9 +163,11 @@ class TorchtunePostTrainingImpl:
            resources_allocated=self._get_resources_allocated(job),
        )

-   async def cancel_training_job(self, job_uuid: str) -> None:
-       self._scheduler.cancel(job_uuid)
+   async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
+       self._scheduler.cancel(request.job_uuid)

-   async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
-       job = self._scheduler.get_job(job_uuid)
-       return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
+   async def get_training_job_artifacts(
+       self, request: GetTrainingJobArtifactsRequest
+   ) -> PostTrainingJobArtifactsResponse | None:
+       job = self._scheduler.get_job(request.job_uuid)
+       return PostTrainingJobArtifactsResponse(job_uuid=request.job_uuid, checkpoints=self._get_checkpoints(job))
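
The job-management accessors change identically in both providers. A hedged sketch of driving the new lifecycle from a caller, assuming only the request types and the response fields constructed above (status, scheduled_at, started_at, checkpoints):

    from llama_stack_api import (
        CancelTrainingJobRequest,
        GetTrainingJobArtifactsRequest,
        GetTrainingJobStatusRequest,
    )

    async def inspect_and_cancel(impl, job_uuid: str) -> None:
        # Each accessor now takes a typed request wrapping the job_uuid.
        status = await impl.get_training_job_status(GetTrainingJobStatusRequest(job_uuid=job_uuid))
        if status is not None:
            print(status.status, status.scheduled_at, status.started_at)

        artifacts = await impl.get_training_job_artifacts(GetTrainingJobArtifactsRequest(job_uuid=job_uuid))
        if artifacts is not None:
            print(artifacts.checkpoints)

        await impl.cancel_training_job(CancelTrainingJobRequest(job_uuid=job_uuid))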

llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py

@@ -50,6 +50,7 @@ from llama_stack_api import (
     DataConfig,
     DatasetIO,
     Datasets,
+    IterRowsRequest,
     LoraFinetuningConfig,
     OptimizerConfig,
     PostTrainingMetric,
@@ -334,10 +335,7 @@
         batch_size: int,
     ) -> tuple[DistributedSampler, DataLoader]:
         async def fetch_rows(dataset_id: str):
-            return await self.datasetio_api.iterrows(
-                dataset_id=dataset_id,
-                limit=-1,
-            )
+            return await self.datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))

         all_rows = await fetch_rows(dataset_id)
         rows = all_rows.data

llama_stack/providers/inline/safety/code_scanner/code_scanner.py

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import uuid
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING

 if TYPE_CHECKING:
     from codeshield.cs import CodeShieldScanResult
@@ -15,9 +15,11 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 from llama_stack_api import (
+    GetShieldRequest,
     ModerationObject,
     ModerationObjectResults,
-    OpenAIMessageParam,
+    RunModerationRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -51,19 +53,14 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
                 f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
             )

-    async def run_shield(
-        self,
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any] = None,
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
+            raise ValueError(f"Shield {request.shield_id} not found")

         from codeshield.cs import CodeShield

-        text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
+        text = "\n".join([interleaved_content_as_str(m.content) for m in request.messages])
         log.info(f"Running CodeScannerShield on {text[50:]}")
         result = await CodeShield.scan_code(text)

@@ -102,11 +99,11 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             metadata=metadata,
         )

-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
-        if model is None:
+    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
+        if request.model is None:
             raise ValueError("Code scanner moderation requires a model identifier.")

-        inputs = input if isinstance(input, list) else [input]
+        inputs = request.input if isinstance(request.input, list) else [request.input]
         results = []

         from codeshield.cs import CodeShield
@@ -129,4 +126,4 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             )
             results.append(moderation_result)

-        return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
+        return ModerationObject(id=str(uuid.uuid4()), model=request.model, results=results)
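
run_moderation now takes a single RunModerationRequest. A caller-side sketch using only the input and model fields the implementation reads above; the model identifier is a placeholder, since the allowed code-scanner IDs are not listed in this diff:

    from llama_stack_api import RunModerationRequest

    async def scan_snippets(safety_impl, snippets: list[str]):
        # `input` may be a single string or a list; this provider requires `model`.
        request = RunModerationRequest(input=snippets, model="code-scanner")  # placeholder id
        moderation = await safety_impl.run_moderation(request)
        return moderation.results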

llama_stack/providers/inline/safety/llama_guard/llama_guard.py

@@ -7,16 +7,15 @@
 import re
 import uuid
 from string import Template
-from typing import Any

 from llama_stack.core.datatypes import Api
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import Role
-from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 from llama_stack_api import (
+    GetShieldRequest,
     ImageContentItem,
     Inference,
     ModerationObject,
@@ -24,6 +23,8 @@ from llama_stack_api import (
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAIMessageParam,
     OpenAIUserMessageParam,
+    RunModerationRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -91,13 +92,13 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [

 # accept both CoreModelId and huggingface repo id
 LLAMA_GUARD_MODEL_IDS = {
-    CoreModelId.llama_guard_3_8b.value: "meta-llama/Llama-Guard-3-8B",
+    "Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
     "meta-llama/Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
-    CoreModelId.llama_guard_3_1b.value: "meta-llama/Llama-Guard-3-1B",
+    "Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
     "meta-llama/Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
-    CoreModelId.llama_guard_3_11b_vision.value: "meta-llama/Llama-Guard-3-11B-Vision",
+    "Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
     "meta-llama/Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
-    CoreModelId.llama_guard_4_12b.value: "meta-llama/Llama-Guard-4-12B",
+    "Llama-Guard-4-12B": "meta-llama/Llama-Guard-4-12B",
     "meta-llama/Llama-Guard-4-12B": "meta-llama/Llama-Guard-4-12B",
 }

@@ -161,17 +162,12 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         # The routing table handles the removal from the registry
         pass

-    async def run_shield(
-        self,
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any] = None,
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Unknown shield {shield_id}")
+            raise ValueError(f"Unknown shield {request.shield_id}")

-        messages = messages.copy()
+        messages = request.messages.copy()
         # some shields like llama-guard require the first message to be a user message
         # since this might be a tool call, first role might not be user
         if len(messages) > 0 and messages[0].role != "user":
@@ -200,30 +196,30 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):

         return await impl.run(messages)

-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
-        if model is None:
+    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
+        if request.model is None:
             raise ValueError("Llama Guard moderation requires a model identifier.")

-        if isinstance(input, list):
-            messages = input.copy()
+        if isinstance(request.input, list):
+            messages = request.input.copy()
         else:
-            messages = [input]
+            messages = [request.input]

         # convert to user messages format with role
         messages = [OpenAIUserMessageParam(content=m) for m in messages]

         # Determine safety categories based on the model type
         # For known Llama Guard models, use specific categories
-        if model in LLAMA_GUARD_MODEL_IDS:
+        if request.model in LLAMA_GUARD_MODEL_IDS:
             # Use the mapped model for categories but the original model_id for inference
-            mapped_model = LLAMA_GUARD_MODEL_IDS[model]
+            mapped_model = LLAMA_GUARD_MODEL_IDS[request.model]
             safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
         else:
             # For unknown models, use default Llama Guard 3 8B categories
             safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]

         impl = LlamaGuardShield(
-            model=model,
+            model=request.model,
             inference_api=self.inference_api,
             excluded_categories=self.config.excluded_categories,
             safety_categories=safety_categories,
@@ -293,7 +289,7 @@ class LlamaGuardShield:
     async def run(self, messages: list[OpenAIMessageParam]) -> RunShieldResponse:
         messages = self.validate_messages(messages)

-        if self.model == CoreModelId.llama_guard_3_11b_vision.value:
+        if self.model == "Llama-Guard-3-11B-Vision":
             shield_input_message = self.build_vision_shield_input(messages)
         else:
             shield_input_message = self.build_text_shield_input(messages)
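
run_shield follows suit: the shield id and messages travel in a RunShieldRequest, and the store lookup itself becomes a GetShieldRequest. A minimal sketch, assuming only the fields read above (shield_id, messages) and the OpenAIUserMessageParam type that already appears in this file:

    from llama_stack_api import OpenAIUserMessageParam, RunShieldRequest

    async def check_message(safety_impl, text: str):
        request = RunShieldRequest(
            shield_id="llama-guard",  # placeholder; use a registered shield identifier
            messages=[OpenAIUserMessageParam(content=text)],
        )
        # Inspect the returned RunShieldResponse for a SafetyViolation, per the API above.
        return await safety_impl.run_shield(request)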

llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py

@@ -4,17 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any
-
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer

 from llama_stack.core.utils.model_utils import model_local_dir
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-    ModerationObject,
+    GetShieldRequest,
     OpenAIMessageParam,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -31,7 +33,7 @@ log = get_logger(name=__name__, category="safety")
 PROMPT_GUARD_MODEL = "Prompt-Guard-86M"


-class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+class PromptGuardSafetyImpl(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
     shield_store: ShieldStore

     def __init__(self, config: PromptGuardConfig, _deps) -> None:
@@ -51,20 +53,12 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     async def unregister_shield(self, identifier: str) -> None:
         pass

-    async def run_shield(
-        self,
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any],
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Unknown shield {shield_id}")
-
-        return await self.shield.run(messages)
+            raise ValueError(f"Unknown shield {request.shield_id}")

-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
-        raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
+        return await self.shield.run(request.messages)


 class PromptGuardShield:
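
Prompt Guard drops its NotImplementedError stub for run_moderation and instead inherits the method from the new ShieldToModerationMixin (added in llama_stack/providers/utils/safety.py, +114 lines in this release). The mixin's body is not shown in this diff; the sketch below is only a guess at the shape such an adapter could take, and every name inside it is an assumption.

    # Speculative sketch: the real ShieldToModerationMixin is not part of this diff.
    import uuid

    from llama_stack_api import (
        ModerationObject,
        OpenAIUserMessageParam,
        RunModerationRequest,
        RunShieldRequest,
    )

    class ShieldToModerationMixinSketch:
        async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
            inputs = request.input if isinstance(request.input, list) else [request.input]
            # Assumed strategy: answer moderation calls by reusing the provider's run_shield.
            await self.run_shield(
                RunShieldRequest(
                    shield_id=request.model,  # assumption: the model id doubles as the shield id
                    messages=[OpenAIUserMessageParam(content=text) for text in inputs],
                )
            )
            # Translating shield results into ModerationObjectResults is elided here.
            return ModerationObject(id=str(uuid.uuid4()), model=request.model, results=[])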

llama_stack/providers/inline/scoring/basic/scoring.py

@@ -3,16 +3,17 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any

 from llama_stack_api import (
     DatasetIO,
     Datasets,
+    IterRowsRequest,
+    ScoreBatchRequest,
     ScoreBatchResponse,
+    ScoreRequest,
     ScoreResponse,
     Scoring,
     ScoringFn,
-    ScoringFnParams,
     ScoringFunctionsProtocolPrivate,
     ScoringResult,
 )
@@ -75,19 +76,15 @@ class BasicScoringImpl(

     async def score_batch(
         self,
-        dataset_id: str,
-        scoring_functions: dict[str, ScoringFnParams | None] = None,
-        save_results_dataset: bool = False,
+        request: ScoreBatchRequest,
     ) -> ScoreBatchResponse:
-        all_rows = await self.datasetio_api.iterrows(
-            dataset_id=dataset_id,
-            limit=-1,
-        )
-        res = await self.score(
+        all_rows = await self.datasetio_api.iterrows(IterRowsRequest(dataset_id=request.dataset_id, limit=-1))
+        score_request = ScoreRequest(
             input_rows=all_rows.data,
-            scoring_functions=scoring_functions,
+            scoring_functions=request.scoring_functions,
         )
-        if save_results_dataset:
+        res = await self.score(score_request)
+        if request.save_results_dataset:
             # TODO: persist and register dataset on to server for reading
             # self.datasets_api.register_dataset()
             raise NotImplementedError("Save results dataset not implemented yet")
@@ -98,16 +95,15 @@ class BasicScoringImpl(

     async def score(
         self,
-        input_rows: list[dict[str, Any]],
-        scoring_functions: dict[str, ScoringFnParams | None] = None,
+        request: ScoreRequest,
     ) -> ScoreResponse:
         res = {}
-        for scoring_fn_id in scoring_functions.keys():
+        for scoring_fn_id in request.scoring_functions.keys():
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
-            score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
+            scoring_fn_params = request.scoring_functions.get(scoring_fn_id, None)
+            score_results = await scoring_fn.score(request.input_rows, scoring_fn_id, scoring_fn_params)
             agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
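
Scoring gets the same conversion: score takes a ScoreRequest, score_batch takes a ScoreBatchRequest, and score_batch now builds a ScoreRequest internally. A caller sketch using only the fields visible above (input_rows, scoring_functions, dataset_id, save_results_dataset); the scoring-function id is a placeholder:

    from llama_stack_api import ScoreBatchRequest, ScoreRequest

    async def score_rows(scoring_impl, rows: list[dict]):
        # Inline scoring over rows already in hand; None means default fn params.
        return await scoring_impl.score(
            ScoreRequest(input_rows=rows, scoring_functions={"basic::equality": None})  # placeholder id
        )

    async def score_dataset(scoring_impl, dataset_id: str):
        # Batch scoring pulls the rows through DatasetIO on the provider side.
        return await scoring_impl.score_batch(
            ScoreBatchRequest(
                dataset_id=dataset_id,
                scoring_functions={"basic::equality": None},  # placeholder id
                save_results_dataset=False,
            )
        )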