llama-stack 0.4.3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (307)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +183 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  71. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  72. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  73. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
  74. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  75. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  76. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  77. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  78. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  79. llama_stack/providers/registry/agents.py +1 -0
  80. llama_stack/providers/registry/inference.py +1 -9
  81. llama_stack/providers/registry/vector_io.py +136 -16
  82. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  83. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  84. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  85. llama_stack/providers/remote/files/s3/README.md +266 -0
  86. llama_stack/providers/remote/files/s3/config.py +5 -3
  87. llama_stack/providers/remote/files/s3/files.py +2 -2
  88. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  89. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  90. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  91. llama_stack/providers/remote/inference/together/together.py +4 -0
  92. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  93. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  94. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  95. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  96. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  97. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  98. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  99. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  100. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  101. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  102. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  103. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  104. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  105. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  106. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  107. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  108. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  109. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  110. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  111. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  112. llama_stack/providers/utils/bedrock/client.py +3 -3
  113. llama_stack/providers/utils/bedrock/config.py +7 -7
  114. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  115. llama_stack/providers/utils/inference/http_client.py +239 -0
  116. llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
  117. llama_stack/providers/utils/inference/model_registry.py +148 -2
  118. llama_stack/providers/utils/inference/openai_compat.py +2 -1
  119. llama_stack/providers/utils/inference/openai_mixin.py +41 -2
  120. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  121. llama_stack/providers/utils/memory/vector_store.py +46 -19
  122. llama_stack/providers/utils/responses/responses_store.py +40 -6
  123. llama_stack/providers/utils/safety.py +114 -0
  124. llama_stack/providers/utils/tools/mcp.py +44 -3
  125. llama_stack/testing/api_recorder.py +9 -3
  126. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
  127. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +131 -275
  128. llama_stack-0.5.0rc1.dist-info/top_level.txt +1 -0
  129. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  130. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  131. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  132. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  133. llama_stack/models/llama/hadamard_utils.py +0 -88
  134. llama_stack/models/llama/llama3/args.py +0 -74
  135. llama_stack/models/llama/llama3/generation.py +0 -378
  136. llama_stack/models/llama/llama3/model.py +0 -304
  137. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  138. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  139. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  140. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  141. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  142. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  143. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  144. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  145. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  146. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  147. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  148. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  149. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  150. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  151. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  152. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  153. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  154. llama_stack/models/llama/llama4/args.py +0 -107
  155. llama_stack/models/llama/llama4/ffn.py +0 -58
  156. llama_stack/models/llama/llama4/moe.py +0 -214
  157. llama_stack/models/llama/llama4/preprocess.py +0 -435
  158. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  159. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  160. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  161. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  162. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  163. llama_stack/models/llama/quantize_impls.py +0 -316
  164. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  165. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  166. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  167. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  168. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  169. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  170. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  171. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  172. llama_stack_api/__init__.py +0 -945
  173. llama_stack_api/admin/__init__.py +0 -45
  174. llama_stack_api/admin/api.py +0 -72
  175. llama_stack_api/admin/fastapi_routes.py +0 -117
  176. llama_stack_api/admin/models.py +0 -113
  177. llama_stack_api/agents.py +0 -173
  178. llama_stack_api/batches/__init__.py +0 -40
  179. llama_stack_api/batches/api.py +0 -53
  180. llama_stack_api/batches/fastapi_routes.py +0 -113
  181. llama_stack_api/batches/models.py +0 -78
  182. llama_stack_api/benchmarks/__init__.py +0 -43
  183. llama_stack_api/benchmarks/api.py +0 -39
  184. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  185. llama_stack_api/benchmarks/models.py +0 -109
  186. llama_stack_api/common/__init__.py +0 -5
  187. llama_stack_api/common/content_types.py +0 -101
  188. llama_stack_api/common/errors.py +0 -95
  189. llama_stack_api/common/job_types.py +0 -38
  190. llama_stack_api/common/responses.py +0 -77
  191. llama_stack_api/common/training_types.py +0 -47
  192. llama_stack_api/common/type_system.py +0 -146
  193. llama_stack_api/connectors.py +0 -146
  194. llama_stack_api/conversations.py +0 -270
  195. llama_stack_api/datasetio.py +0 -55
  196. llama_stack_api/datasets/__init__.py +0 -61
  197. llama_stack_api/datasets/api.py +0 -35
  198. llama_stack_api/datasets/fastapi_routes.py +0 -104
  199. llama_stack_api/datasets/models.py +0 -152
  200. llama_stack_api/datatypes.py +0 -373
  201. llama_stack_api/eval.py +0 -137
  202. llama_stack_api/file_processors/__init__.py +0 -27
  203. llama_stack_api/file_processors/api.py +0 -64
  204. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  205. llama_stack_api/file_processors/models.py +0 -42
  206. llama_stack_api/files/__init__.py +0 -35
  207. llama_stack_api/files/api.py +0 -51
  208. llama_stack_api/files/fastapi_routes.py +0 -124
  209. llama_stack_api/files/models.py +0 -107
  210. llama_stack_api/inference.py +0 -1169
  211. llama_stack_api/inspect_api/__init__.py +0 -37
  212. llama_stack_api/inspect_api/api.py +0 -25
  213. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  214. llama_stack_api/inspect_api/models.py +0 -28
  215. llama_stack_api/internal/kvstore.py +0 -28
  216. llama_stack_api/internal/sqlstore.py +0 -81
  217. llama_stack_api/llama_stack_api/__init__.py +0 -945
  218. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  219. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  220. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  221. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  222. llama_stack_api/llama_stack_api/agents.py +0 -173
  223. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  224. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  225. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  226. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  227. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  228. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  229. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  230. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  231. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  232. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  233. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  234. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  235. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  236. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  237. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  238. llama_stack_api/llama_stack_api/connectors.py +0 -146
  239. llama_stack_api/llama_stack_api/conversations.py +0 -270
  240. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  241. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  242. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  243. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  244. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  245. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  246. llama_stack_api/llama_stack_api/eval.py +0 -137
  247. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  248. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  249. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  250. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  251. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  252. llama_stack_api/llama_stack_api/files/api.py +0 -51
  253. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  254. llama_stack_api/llama_stack_api/files/models.py +0 -107
  255. llama_stack_api/llama_stack_api/inference.py +0 -1169
  256. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  257. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  258. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  259. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  260. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  261. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  262. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  263. llama_stack_api/llama_stack_api/models.py +0 -171
  264. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  265. llama_stack_api/llama_stack_api/post_training.py +0 -370
  266. llama_stack_api/llama_stack_api/prompts.py +0 -203
  267. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  268. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  269. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  270. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  271. llama_stack_api/llama_stack_api/py.typed +0 -0
  272. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  273. llama_stack_api/llama_stack_api/resource.py +0 -37
  274. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  275. llama_stack_api/llama_stack_api/safety.py +0 -132
  276. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  277. llama_stack_api/llama_stack_api/scoring.py +0 -93
  278. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  279. llama_stack_api/llama_stack_api/shields.py +0 -93
  280. llama_stack_api/llama_stack_api/tools.py +0 -226
  281. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  282. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  283. llama_stack_api/llama_stack_api/version.py +0 -9
  284. llama_stack_api/models.py +0 -171
  285. llama_stack_api/openai_responses.py +0 -1468
  286. llama_stack_api/post_training.py +0 -370
  287. llama_stack_api/prompts.py +0 -203
  288. llama_stack_api/providers/__init__.py +0 -33
  289. llama_stack_api/providers/api.py +0 -16
  290. llama_stack_api/providers/fastapi_routes.py +0 -57
  291. llama_stack_api/providers/models.py +0 -24
  292. llama_stack_api/py.typed +0 -0
  293. llama_stack_api/rag_tool.py +0 -168
  294. llama_stack_api/resource.py +0 -37
  295. llama_stack_api/router_utils.py +0 -160
  296. llama_stack_api/safety.py +0 -132
  297. llama_stack_api/schema_utils.py +0 -208
  298. llama_stack_api/scoring.py +0 -93
  299. llama_stack_api/scoring_functions.py +0 -211
  300. llama_stack_api/shields.py +0 -93
  301. llama_stack_api/tools.py +0 -226
  302. llama_stack_api/vector_io.py +0 -941
  303. llama_stack_api/vector_stores.py +0 -53
  304. llama_stack_api/version.py +0 -9
  305. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
  306. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
  307. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
llama_stack/providers/inline/post_training/huggingface/post_training.py

@@ -12,17 +12,19 @@ from llama_stack.providers.inline.post_training.huggingface.config import (
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack_api import (
-    AlgorithmConfig,
+    CancelTrainingJobRequest,
     Checkpoint,
     DatasetIO,
     Datasets,
-    DPOAlignmentConfig,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     ListPostTrainingJobsResponse,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-    TrainingConfig,
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )


@@ -69,13 +71,7 @@ class HuggingFacePostTrainingImpl:

     async def supervised_fine_tune(
         self,
-        job_uuid: str,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None = None,
-        algorithm_config: AlgorithmConfig | None = None,
+        request: SupervisedFineTuneRequest,
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
@@ -85,17 +81,17 @@ class HuggingFacePostTrainingImpl:
             on_log_message_cb("Starting HF finetuning")

             recipe = HFFinetuningSingleDevice(
-                job_uuid=job_uuid,
+                job_uuid=request.job_uuid,
                 datasetio_api=self.datasetio_api,
                 datasets_api=self.datasets_api,
             )

             resources_allocated, checkpoints = await recipe.train(
-                model=model,
-                output_dir=checkpoint_dir,
-                job_uuid=job_uuid,
-                lora_config=algorithm_config,
-                config=training_config,
+                model=request.model,
+                output_dir=request.checkpoint_dir,
+                job_uuid=request.job_uuid,
+                lora_config=request.algorithm_config,
+                config=request.training_config,
                 provider_config=self.config,
             )

@@ -108,17 +104,12 @@ class HuggingFacePostTrainingImpl:
             on_status_change_cb(SchedulerJobStatus.completed)
             on_log_message_cb("HF finetuning completed")

-        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)

     async def preference_optimize(
         self,
-        job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
@@ -128,17 +119,17 @@ class HuggingFacePostTrainingImpl:
             on_log_message_cb("Starting HF DPO alignment")

             recipe = HFDPOAlignmentSingleDevice(
-                job_uuid=job_uuid,
+                job_uuid=request.job_uuid,
                 datasetio_api=self.datasetio_api,
                 datasets_api=self.datasets_api,
             )

             resources_allocated, checkpoints = await recipe.train(
-                model=finetuned_model,
-                output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
-                job_uuid=job_uuid,
-                dpo_config=algorithm_config,
-                config=training_config,
+                model=request.finetuned_model,
+                output_dir=f"{self.config.dpo_output_dir}/{request.job_uuid}",
+                job_uuid=request.job_uuid,
+                dpo_config=request.algorithm_config,
+                config=request.training_config,
                 provider_config=self.config,
             )

@@ -153,7 +144,7 @@ class HuggingFacePostTrainingImpl:
             on_status_change_cb(SchedulerJobStatus.completed)
             on_log_message_cb("HF DPO alignment completed")

-        job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)

     @staticmethod
@@ -169,8 +160,10 @@ class HuggingFacePostTrainingImpl:
         data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
         return data[0] if data else None

-    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
-        job = self._scheduler.get_job(job_uuid)
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> PostTrainingJobStatusResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)

         match job.status:
             # TODO: Add support for other statuses to API
@@ -186,7 +179,7 @@ class HuggingFacePostTrainingImpl:
                 raise NotImplementedError()

         return PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             status=status,
             scheduled_at=job.scheduled_at,
             started_at=job.started_at,
@@ -195,12 +188,14 @@ class HuggingFacePostTrainingImpl:
             resources_allocated=self._get_resources_allocated(job),
         )

-    async def cancel_training_job(self, job_uuid: str) -> None:
-        self._scheduler.cancel(job_uuid)
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
+        self._scheduler.cancel(request.job_uuid)

-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
-        job = self._scheduler.get_job(job_uuid)
-        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=request.job_uuid, checkpoints=self._get_checkpoints(job))

     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
         return ListPostTrainingJobsResponse(
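
The change above replaces flat keyword arguments with single request models across the post-training API. A minimal caller-side sketch of the new convention; the field names match those dereferenced in the diff, but the `post_training` handle, the model and job ids, and the `training_config` contents are illustrative assumptions:

    from llama_stack_api import (
        CancelTrainingJobRequest,
        GetTrainingJobStatusRequest,
        SupervisedFineTuneRequest,
    )

    # 0.4.3 style (removed): supervised_fine_tune(job_uuid=..., model=..., training_config=..., ...)
    # 0.5.0rc1 style: one request object carrying the same fields.
    job = await post_training.supervised_fine_tune(
        SupervisedFineTuneRequest(
            job_uuid="sft-job-001",            # caller-chosen identifier, as before
            model="meta-llama/Llama-3.2-1B",   # illustrative model id
            training_config=training_config,   # a TrainingConfig built elsewhere
            hyperparam_search_config={},
            logger_config={},
            checkpoint_dir=None,
            algorithm_config=None,             # e.g. a LoraFinetuningConfig for LoRA
        )
    )

    # The job-management endpoints take request objects too.
    status = await post_training.get_training_job_status(GetTrainingJobStatusRequest(job_uuid=job.job_uuid))
    await post_training.cancel_training_job(CancelTrainingJobRequest(job_uuid=job.job_uuid))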
llama_stack/providers/inline/post_training/huggingface/utils.py

@@ -16,7 +16,7 @@ import torch
 from datasets import Dataset
 from transformers import AutoConfig, AutoModelForCausalLM

-from llama_stack_api import Checkpoint, DatasetIO, TrainingConfig
+from llama_stack_api import Checkpoint, DatasetIO, IterRowsRequest, TrainingConfig

 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -135,10 +135,7 @@ def setup_torch_device(device_str: str) -> torch.device:
 async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
     """Load dataset from llama stack dataset provider"""
     try:
-        all_rows = await datasetio_api.iterrows(
-            dataset_id=dataset_id,
-            limit=-1,
-        )
+        all_rows = await datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))
        if not isinstance(all_rows.data, list):
            raise RuntimeError("Expected dataset data to be a list")
        return all_rows.data
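
The same migration applies to `DatasetIO.iterrows`, which now takes an `IterRowsRequest` instead of keyword arguments. A short sketch; the dataset id is a placeholder, and `limit=-1` is used as in the recipes above to fetch all rows:

    from llama_stack_api import IterRowsRequest

    # Old: await datasetio_api.iterrows(dataset_id="my-dataset", limit=-1)
    page = await datasetio_api.iterrows(IterRowsRequest(dataset_id="my-dataset", limit=-1))
    rows = page.data  # list of row dicts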
llama_stack/providers/inline/post_training/torchtune/post_training.py

@@ -12,18 +12,20 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack_api import (
-    AlgorithmConfig,
+    CancelTrainingJobRequest,
     Checkpoint,
     DatasetIO,
     Datasets,
-    DPOAlignmentConfig,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     ListPostTrainingJobsResponse,
     LoraFinetuningConfig,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-    TrainingConfig,
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )


@@ -69,15 +71,9 @@ class TorchtunePostTrainingImpl:

     async def supervised_fine_tune(
         self,
-        job_uuid: str,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None,
-        algorithm_config: AlgorithmConfig | None,
+        request: SupervisedFineTuneRequest,
     ) -> PostTrainingJob:
-        if isinstance(algorithm_config, LoraFinetuningConfig):
+        if isinstance(request.algorithm_config, LoraFinetuningConfig):

            async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
                from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
@@ -88,13 +84,13 @@ class TorchtunePostTrainingImpl:

                recipe = LoraFinetuningSingleDevice(
                    self.config,
-                    job_uuid,
-                    training_config,
-                    hyperparam_search_config,
-                    logger_config,
-                    model,
-                    checkpoint_dir,
-                    algorithm_config,
+                    request.job_uuid,
+                    request.training_config,
+                    request.hyperparam_search_config,
+                    request.logger_config,
+                    request.model,
+                    request.checkpoint_dir,
+                    request.algorithm_config,
                    self.datasetio_api,
                    self.datasets_api,
                )
@@ -112,17 +108,12 @@ class TorchtunePostTrainingImpl:
         else:
             raise NotImplementedError()

-        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)

     async def preference_optimize(
         self,
-        job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         raise NotImplementedError()

@@ -144,8 +135,10 @@ class TorchtunePostTrainingImpl:
         data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
         return data[0] if data else None

-    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
-        job = self._scheduler.get_job(job_uuid)
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> PostTrainingJobStatusResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)

         match job.status:
             # TODO: Add support for other statuses to API
@@ -161,7 +154,7 @@ class TorchtunePostTrainingImpl:
             raise NotImplementedError()

         return PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             status=status,
             scheduled_at=job.scheduled_at,
             started_at=job.started_at,
@@ -170,9 +163,11 @@ class TorchtunePostTrainingImpl:
             resources_allocated=self._get_resources_allocated(job),
         )

-    async def cancel_training_job(self, job_uuid: str) -> None:
-        self._scheduler.cancel(job_uuid)
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
+        self._scheduler.cancel(request.job_uuid)

-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
-        job = self._scheduler.get_job(job_uuid)
-        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=request.job_uuid, checkpoints=self._get_checkpoints(job))
llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py

@@ -50,6 +50,7 @@ from llama_stack_api import (
     DataConfig,
     DatasetIO,
     Datasets,
+    IterRowsRequest,
     LoraFinetuningConfig,
     OptimizerConfig,
     PostTrainingMetric,
@@ -334,10 +335,7 @@ class LoraFinetuningSingleDevice:
         batch_size: int,
     ) -> tuple[DistributedSampler, DataLoader]:
         async def fetch_rows(dataset_id: str):
-            return await self.datasetio_api.iterrows(
-                dataset_id=dataset_id,
-                limit=-1,
-            )
+            return await self.datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))

         all_rows = await fetch_rows(dataset_id)
         rows = all_rows.data
llama_stack/providers/inline/safety/code_scanner/code_scanner.py

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import uuid
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING

 if TYPE_CHECKING:
     from codeshield.cs import CodeShieldScanResult
@@ -15,9 +15,11 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 from llama_stack_api import (
+    GetShieldRequest,
     ModerationObject,
     ModerationObjectResults,
-    OpenAIMessageParam,
+    RunModerationRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -51,19 +53,14 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
                 f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
             )

-    async def run_shield(
-        self,
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any] = None,
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
+            raise ValueError(f"Shield {request.shield_id} not found")

         from codeshield.cs import CodeShield

-        text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
+        text = "\n".join([interleaved_content_as_str(m.content) for m in request.messages])
         log.info(f"Running CodeScannerShield on {text[50:]}")
         result = await CodeShield.scan_code(text)

@@ -102,11 +99,11 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             metadata=metadata,
         )

-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
-        if model is None:
+    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
+        if request.model is None:
             raise ValueError("Code scanner moderation requires a model identifier.")

-        inputs = input if isinstance(input, list) else [input]
+        inputs = request.input if isinstance(request.input, list) else [request.input]
         results = []

         from codeshield.cs import CodeShield
@@ -129,4 +126,4 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             )
             results.append(moderation_result)

-        return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
+        return ModerationObject(id=str(uuid.uuid4()), model=request.model, results=results)
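
Safety providers follow the same pattern: `run_shield` and `run_moderation` now accept `RunShieldRequest` and `RunModerationRequest`. A caller-side sketch; the `safety` handle, shield id, and model name are placeholders, and the `violation` field on `RunShieldResponse` is inferred from the `SafetyViolation` import above:

    from llama_stack_api import OpenAIUserMessageParam, RunModerationRequest, RunShieldRequest

    resp = await safety.run_shield(
        RunShieldRequest(
            shield_id="code-scanner",  # placeholder: must match a registered shield
            messages=[OpenAIUserMessageParam(content="exec(user_supplied_code)")],
        )
    )
    if resp.violation:
        print(resp.violation.user_message)

    # run_moderation now requires an explicit model identifier, as both providers above enforce.
    moderation = await safety.run_moderation(
        RunModerationRequest(input=["exec(user_supplied_code)"], model="CodeScanner")
    )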
llama_stack/providers/inline/safety/llama_guard/llama_guard.py

@@ -7,7 +7,6 @@
 import re
 import uuid
 from string import Template
-from typing import Any

 from llama_stack.core.datatypes import Api
 from llama_stack.log import get_logger
@@ -17,6 +16,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 from llama_stack_api import (
+    GetShieldRequest,
     ImageContentItem,
     Inference,
     ModerationObject,
@@ -24,6 +24,8 @@ from llama_stack_api import (
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAIMessageParam,
     OpenAIUserMessageParam,
+    RunModerationRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -161,17 +163,12 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         # The routing table handles the removal from the registry
         pass

-    async def run_shield(
-        self,
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any] = None,
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Unknown shield {shield_id}")
+            raise ValueError(f"Unknown shield {request.shield_id}")

-        messages = messages.copy()
+        messages = request.messages.copy()
         # some shields like llama-guard require the first message to be a user message
         # since this might be a tool call, first role might not be user
         if len(messages) > 0 and messages[0].role != "user":
@@ -200,30 +197,30 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):

         return await impl.run(messages)

-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
-        if model is None:
+    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
+        if request.model is None:
             raise ValueError("Llama Guard moderation requires a model identifier.")

-        if isinstance(input, list):
-            messages = input.copy()
+        if isinstance(request.input, list):
+            messages = request.input.copy()
         else:
-            messages = [input]
+            messages = [request.input]

         # convert to user messages format with role
         messages = [OpenAIUserMessageParam(content=m) for m in messages]

         # Determine safety categories based on the model type
         # For known Llama Guard models, use specific categories
-        if model in LLAMA_GUARD_MODEL_IDS:
+        if request.model in LLAMA_GUARD_MODEL_IDS:
             # Use the mapped model for categories but the original model_id for inference
-            mapped_model = LLAMA_GUARD_MODEL_IDS[model]
+            mapped_model = LLAMA_GUARD_MODEL_IDS[request.model]
             safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
         else:
             # For unknown models, use default Llama Guard 3 8B categories
             safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]

         impl = LlamaGuardShield(
-            model=model,
+            model=request.model,
             inference_api=self.inference_api,
             excluded_categories=self.config.excluded_categories,
             safety_categories=safety_categories,
llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py

@@ -4,17 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any
-
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer

 from llama_stack.core.utils.model_utils import model_local_dir
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-    ModerationObject,
+    GetShieldRequest,
     OpenAIMessageParam,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -31,7 +33,7 @@ log = get_logger(name=__name__, category="safety")
 PROMPT_GUARD_MODEL = "Prompt-Guard-86M"


-class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+class PromptGuardSafetyImpl(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
     shield_store: ShieldStore

     def __init__(self, config: PromptGuardConfig, _deps) -> None:
@@ -51,20 +53,12 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     async def unregister_shield(self, identifier: str) -> None:
         pass

-    async def run_shield(
-        self,
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any],
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Unknown shield {shield_id}")
-
-        return await self.shield.run(messages)
+            raise ValueError(f"Unknown shield {request.shield_id}")

-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
-        raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
+        return await self.shield.run(request.messages)


 class PromptGuardShield:
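
`PromptGuardSafetyImpl` previously raised `NotImplementedError` from `run_moderation`; that method now comes from `ShieldToModerationMixin` in the new `llama_stack/providers/utils/safety.py` (file 123 above, not shown in this diff). The mixin's real body is not visible here; the sketch below is only a plausible shape for such an adapter, and every detail beyond the class name is hypothetical:

    import uuid

    from llama_stack_api import (
        ModerationObject,
        ModerationObjectResults,
        OpenAIUserMessageParam,
        RunModerationRequest,
        RunShieldRequest,
    )


    class ShieldToModerationMixin:
        """Hypothetical sketch: back run_moderation with the provider's own run_shield."""

        async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
            inputs = request.input if isinstance(request.input, list) else [request.input]
            results = []
            for text in inputs:
                shield_resp = await self.run_shield(
                    RunShieldRequest(
                        shield_id=request.model,  # assumes the shield is registered under the model id
                        messages=[OpenAIUserMessageParam(content=text)],
                    )
                )
                # Map the shield verdict onto the moderation result shape.
                results.append(ModerationObjectResults(flagged=shield_resp.violation is not None, categories={}))
            return ModerationObject(id=str(uuid.uuid4()), model=request.model, results=results)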
llama_stack/providers/inline/scoring/basic/scoring.py

@@ -3,16 +3,17 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any

 from llama_stack_api import (
     DatasetIO,
     Datasets,
+    IterRowsRequest,
+    ScoreBatchRequest,
     ScoreBatchResponse,
+    ScoreRequest,
     ScoreResponse,
     Scoring,
     ScoringFn,
-    ScoringFnParams,
     ScoringFunctionsProtocolPrivate,
     ScoringResult,
 )
@@ -75,19 +76,15 @@ class BasicScoringImpl(

     async def score_batch(
         self,
-        dataset_id: str,
-        scoring_functions: dict[str, ScoringFnParams | None] = None,
-        save_results_dataset: bool = False,
+        request: ScoreBatchRequest,
     ) -> ScoreBatchResponse:
-        all_rows = await self.datasetio_api.iterrows(
-            dataset_id=dataset_id,
-            limit=-1,
-        )
-        res = await self.score(
+        all_rows = await self.datasetio_api.iterrows(IterRowsRequest(dataset_id=request.dataset_id, limit=-1))
+        score_request = ScoreRequest(
             input_rows=all_rows.data,
-            scoring_functions=scoring_functions,
+            scoring_functions=request.scoring_functions,
         )
-        if save_results_dataset:
+        res = await self.score(score_request)
+        if request.save_results_dataset:
             # TODO: persist and register dataset on to server for reading
             # self.datasets_api.register_dataset()
             raise NotImplementedError("Save results dataset not implemented yet")
@@ -98,16 +95,15 @@ class BasicScoringImpl(

     async def score(
         self,
-        input_rows: list[dict[str, Any]],
-        scoring_functions: dict[str, ScoringFnParams | None] = None,
+        request: ScoreRequest,
     ) -> ScoreResponse:
         res = {}
-        for scoring_fn_id in scoring_functions.keys():
+        for scoring_fn_id in request.scoring_functions.keys():
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
-            score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
+            scoring_fn_params = request.scoring_functions.get(scoring_fn_id, None)
+            score_results = await scoring_fn.score(request.input_rows, scoring_fn_id, scoring_fn_params)
             agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
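
The scoring API migrates to `ScoreRequest`/`ScoreBatchRequest` in the same way. A caller-side sketch; the `scoring` handle, row shape, and `basic::equality` function id are illustrative, and passing `None` as a params value keeps the scoring function's defaults, as in the lookup above:

    from llama_stack_api import ScoreBatchRequest, ScoreRequest

    # Score rows already in memory.
    resp = await scoring.score(
        ScoreRequest(
            input_rows=[{"input_query": "2+2", "generated_answer": "4", "expected_answer": "4"}],
            scoring_functions={"basic::equality": None},  # None -> use the function's default params
        )
    )
    print(resp.results["basic::equality"].aggregated_results)

    # Or score a registered dataset end to end; save_results_dataset is still unimplemented above.
    batch = await scoring.score_batch(
        ScoreBatchRequest(
            dataset_id="my-eval-dataset",
            scoring_functions={"basic::equality": None},
            save_results_dataset=False,
        )
    )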
llama_stack/providers/inline/scoring/braintrust/braintrust.py

@@ -29,11 +29,13 @@ from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metr
 from llama_stack_api import (
     DatasetIO,
     Datasets,
+    IterRowsRequest,
+    ScoreBatchRequest,
     ScoreBatchResponse,
+    ScoreRequest,
     ScoreResponse,
     Scoring,
     ScoringFn,
-    ScoringFnParams,
     ScoringFunctionsProtocolPrivate,
     ScoringResult,
     ScoringResultRow,
@@ -158,18 +160,17 @@ class BraintrustScoringImpl(

     async def score_batch(
         self,
-        dataset_id: str,
-        scoring_functions: dict[str, ScoringFnParams | None],
-        save_results_dataset: bool = False,
+        request: ScoreBatchRequest,
     ) -> ScoreBatchResponse:
         await self.set_api_key()

-        all_rows = await self.datasetio_api.iterrows(
-            dataset_id=dataset_id,
-            limit=-1,
+        all_rows = await self.datasetio_api.iterrows(IterRowsRequest(dataset_id=request.dataset_id, limit=-1))
+        score_request = ScoreRequest(
+            input_rows=all_rows.data,
+            scoring_functions=request.scoring_functions,
         )
-        res = await self.score(input_rows=all_rows.data, scoring_functions=scoring_functions)
-        if save_results_dataset:
+        res = await self.score(score_request)
+        if request.save_results_dataset:
             # TODO: persist and register dataset on to server for reading
             # self.datasets_api.register_dataset()
             raise NotImplementedError("Save results dataset not implemented yet")
@@ -198,21 +199,20 @@ class BraintrustScoringImpl(

     async def score(
         self,
-        input_rows: list[dict[str, Any]],
-        scoring_functions: dict[str, ScoringFnParams | None],
+        request: ScoreRequest,
     ) -> ScoreResponse:
         await self.set_api_key()
         res = {}
-        for scoring_fn_id in scoring_functions:
+        for scoring_fn_id in request.scoring_functions:
             if scoring_fn_id not in self.supported_fn_defs_registry:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")

-            score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
+            score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in request.input_rows]
             aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions

             # override scoring_fn params if provided
-            if scoring_functions[scoring_fn_id] is not None:
-                override_params = scoring_functions[scoring_fn_id]
+            if request.scoring_functions[scoring_fn_id] is not None:
+                override_params = request.scoring_functions[scoring_fn_id]
                 if override_params.aggregation_functions:
                     aggregation_functions = override_params.aggregation_functions