llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (311)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  71. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  72. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  73. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  74. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  75. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  76. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  77. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  78. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  79. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  80. llama_stack/providers/registry/agents.py +1 -0
  81. llama_stack/providers/registry/inference.py +1 -9
  82. llama_stack/providers/registry/vector_io.py +136 -16
  83. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  84. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  85. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  86. llama_stack/providers/remote/files/s3/README.md +266 -0
  87. llama_stack/providers/remote/files/s3/config.py +5 -3
  88. llama_stack/providers/remote/files/s3/files.py +2 -2
  89. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  90. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  91. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  92. llama_stack/providers/remote/inference/together/together.py +4 -0
  93. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  94. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  95. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  96. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  97. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  98. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  99. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  100. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  101. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  102. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  103. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  104. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  105. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  106. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  107. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  108. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  109. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  110. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  111. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  112. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  113. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  114. llama_stack/providers/utils/bedrock/client.py +3 -3
  115. llama_stack/providers/utils/bedrock/config.py +7 -7
  116. llama_stack/providers/utils/inference/__init__.py +0 -25
  117. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  118. llama_stack/providers/utils/inference/http_client.py +239 -0
  119. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  120. llama_stack/providers/utils/inference/model_registry.py +148 -2
  121. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  122. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  123. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  124. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  125. llama_stack/providers/utils/memory/vector_store.py +46 -19
  126. llama_stack/providers/utils/responses/responses_store.py +40 -6
  127. llama_stack/providers/utils/safety.py +114 -0
  128. llama_stack/providers/utils/tools/mcp.py +44 -3
  129. llama_stack/testing/api_recorder.py +9 -3
  130. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  131. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
  132. llama_stack-0.5.0.dist-info/top_level.txt +1 -0
  133. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  134. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  135. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  136. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  137. llama_stack/models/llama/hadamard_utils.py +0 -88
  138. llama_stack/models/llama/llama3/args.py +0 -74
  139. llama_stack/models/llama/llama3/generation.py +0 -378
  140. llama_stack/models/llama/llama3/model.py +0 -304
  141. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  142. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  143. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  144. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  145. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  146. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  147. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  148. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  149. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  150. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  151. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  152. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  153. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  154. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  155. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  156. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  157. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  158. llama_stack/models/llama/llama4/args.py +0 -107
  159. llama_stack/models/llama/llama4/ffn.py +0 -58
  160. llama_stack/models/llama/llama4/moe.py +0 -214
  161. llama_stack/models/llama/llama4/preprocess.py +0 -435
  162. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  163. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  164. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  165. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  166. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  167. llama_stack/models/llama/quantize_impls.py +0 -316
  168. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  169. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  170. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  171. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  172. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  173. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  174. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  175. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  176. llama_stack_api/__init__.py +0 -945
  177. llama_stack_api/admin/__init__.py +0 -45
  178. llama_stack_api/admin/api.py +0 -72
  179. llama_stack_api/admin/fastapi_routes.py +0 -117
  180. llama_stack_api/admin/models.py +0 -113
  181. llama_stack_api/agents.py +0 -173
  182. llama_stack_api/batches/__init__.py +0 -40
  183. llama_stack_api/batches/api.py +0 -53
  184. llama_stack_api/batches/fastapi_routes.py +0 -113
  185. llama_stack_api/batches/models.py +0 -78
  186. llama_stack_api/benchmarks/__init__.py +0 -43
  187. llama_stack_api/benchmarks/api.py +0 -39
  188. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  189. llama_stack_api/benchmarks/models.py +0 -109
  190. llama_stack_api/common/__init__.py +0 -5
  191. llama_stack_api/common/content_types.py +0 -101
  192. llama_stack_api/common/errors.py +0 -95
  193. llama_stack_api/common/job_types.py +0 -38
  194. llama_stack_api/common/responses.py +0 -77
  195. llama_stack_api/common/training_types.py +0 -47
  196. llama_stack_api/common/type_system.py +0 -146
  197. llama_stack_api/connectors.py +0 -146
  198. llama_stack_api/conversations.py +0 -270
  199. llama_stack_api/datasetio.py +0 -55
  200. llama_stack_api/datasets/__init__.py +0 -61
  201. llama_stack_api/datasets/api.py +0 -35
  202. llama_stack_api/datasets/fastapi_routes.py +0 -104
  203. llama_stack_api/datasets/models.py +0 -152
  204. llama_stack_api/datatypes.py +0 -373
  205. llama_stack_api/eval.py +0 -137
  206. llama_stack_api/file_processors/__init__.py +0 -27
  207. llama_stack_api/file_processors/api.py +0 -64
  208. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  209. llama_stack_api/file_processors/models.py +0 -42
  210. llama_stack_api/files/__init__.py +0 -35
  211. llama_stack_api/files/api.py +0 -51
  212. llama_stack_api/files/fastapi_routes.py +0 -124
  213. llama_stack_api/files/models.py +0 -107
  214. llama_stack_api/inference.py +0 -1169
  215. llama_stack_api/inspect_api/__init__.py +0 -37
  216. llama_stack_api/inspect_api/api.py +0 -25
  217. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  218. llama_stack_api/inspect_api/models.py +0 -28
  219. llama_stack_api/internal/kvstore.py +0 -28
  220. llama_stack_api/internal/sqlstore.py +0 -81
  221. llama_stack_api/llama_stack_api/__init__.py +0 -945
  222. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  223. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  224. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  225. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  226. llama_stack_api/llama_stack_api/agents.py +0 -173
  227. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  228. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  229. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  230. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  231. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  232. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  233. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  234. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  235. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  236. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  237. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  238. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  239. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  240. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  241. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  242. llama_stack_api/llama_stack_api/connectors.py +0 -146
  243. llama_stack_api/llama_stack_api/conversations.py +0 -270
  244. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  245. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  246. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  247. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  248. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  249. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  250. llama_stack_api/llama_stack_api/eval.py +0 -137
  251. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  252. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  253. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  254. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  255. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  256. llama_stack_api/llama_stack_api/files/api.py +0 -51
  257. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  258. llama_stack_api/llama_stack_api/files/models.py +0 -107
  259. llama_stack_api/llama_stack_api/inference.py +0 -1169
  260. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  261. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  262. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  263. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  264. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  265. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  266. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  267. llama_stack_api/llama_stack_api/models.py +0 -171
  268. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  269. llama_stack_api/llama_stack_api/post_training.py +0 -370
  270. llama_stack_api/llama_stack_api/prompts.py +0 -203
  271. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  272. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  273. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  274. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  275. llama_stack_api/llama_stack_api/py.typed +0 -0
  276. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  277. llama_stack_api/llama_stack_api/resource.py +0 -37
  278. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  279. llama_stack_api/llama_stack_api/safety.py +0 -132
  280. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  281. llama_stack_api/llama_stack_api/scoring.py +0 -93
  282. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  283. llama_stack_api/llama_stack_api/shields.py +0 -93
  284. llama_stack_api/llama_stack_api/tools.py +0 -226
  285. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  286. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  287. llama_stack_api/llama_stack_api/version.py +0 -9
  288. llama_stack_api/models.py +0 -171
  289. llama_stack_api/openai_responses.py +0 -1468
  290. llama_stack_api/post_training.py +0 -370
  291. llama_stack_api/prompts.py +0 -203
  292. llama_stack_api/providers/__init__.py +0 -33
  293. llama_stack_api/providers/api.py +0 -16
  294. llama_stack_api/providers/fastapi_routes.py +0 -57
  295. llama_stack_api/providers/models.py +0 -24
  296. llama_stack_api/py.typed +0 -0
  297. llama_stack_api/rag_tool.py +0 -168
  298. llama_stack_api/resource.py +0 -37
  299. llama_stack_api/router_utils.py +0 -160
  300. llama_stack_api/safety.py +0 -132
  301. llama_stack_api/schema_utils.py +0 -208
  302. llama_stack_api/scoring.py +0 -93
  303. llama_stack_api/scoring_functions.py +0 -211
  304. llama_stack_api/shields.py +0 -93
  305. llama_stack_api/tools.py +0 -226
  306. llama_stack_api/vector_io.py +0 -941
  307. llama_stack_api/vector_stores.py +0 -53
  308. llama_stack_api/version.py +0 -9
  309. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  310. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  311. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -40,9 +40,12 @@ class VertexAIInferenceAdapter(OpenAIMixin):
         Get the Vertex AI OpenAI-compatible API base URL.
 
         Returns the Vertex AI OpenAI-compatible endpoint URL.
-        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
+        Source: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/start/openai
         """
-        return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
+        if not self.config.location or self.config.location == "global":
+            return f"https://aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/global/endpoints/openapi"
+        else:
+            return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
 
     async def list_provider_model_ids(self) -> Iterable[str]:
         """
@@ -4,11 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import warnings
 from pathlib import Path
 
-from pydantic import Field, HttpUrl, SecretStr, field_validator
+from pydantic import Field, HttpUrl, SecretStr, model_validator
 
-from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
+from llama_stack.providers.utils.inference.model_registry import (
+    NetworkConfig,
+    RemoteInferenceProviderConfig,
+    TLSConfig,
+)
 from llama_stack_api import json_schema_type
 
 
@@ -27,23 +32,33 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
         alias="api_token",
         description="The API token",
     )
-    tls_verify: bool | str = Field(
-        default=True,
-        description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
+    tls_verify: bool | str | None = Field(
+        default=None,
+        deprecated=True,
+        description="DEPRECATED: Use 'network.tls.verify' instead. Whether to verify TLS certificates. "
+        "Can be a boolean or a path to a CA certificate file.",
     )
 
-    @field_validator("tls_verify")
-    @classmethod
-    def validate_tls_verify(cls, v):
-        if isinstance(v, str):
-            # Otherwise, treat it as a cert path
-            cert_path = Path(v).expanduser().resolve()
-            if not cert_path.exists():
-                raise ValueError(f"TLS certificate file does not exist: {v}")
-            if not cert_path.is_file():
-                raise ValueError(f"TLS certificate path is not a file: {v}")
-            return v
-        return v
+    @model_validator(mode="after")
+    def migrate_tls_verify_to_network(self) -> "VLLMInferenceAdapterConfig":
+        """Migrate legacy tls_verify to network.tls.verify for backward compatibility."""
+        if self.tls_verify is not None:
+            warnings.warn(
+                "The 'tls_verify' config option is deprecated. Please use 'network.tls.verify' instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            # Convert string path to Path if needed
+            if isinstance(self.tls_verify, str):
+                verify_value: bool | Path = Path(self.tls_verify)
+            else:
+                verify_value = self.tls_verify
+
+            if self.network is None:
+                self.network = NetworkConfig(tls=TLSConfig(verify=verify_value))
+            elif self.network.tls is None:
+                self.network.tls = TLSConfig(verify=verify_value)
+        return self
 
     @classmethod
     def sample_run_config(
@@ -55,5 +70,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
             "base_url": base_url,
             "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
             "api_token": "${env.VLLM_API_TOKEN:=fake}",
-            "tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
+            "network": {
+                "tls": {
+                    "verify": "${env.VLLM_TLS_VERIFY:=true}",
+                },
+            },
         }
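
These three hunks (from llama_stack/providers/remote/inference/vllm/config.py, per the file list) deprecate the flat `tls_verify` option in favor of a nested `network.tls.verify` setting, with a `model_validator(mode="after")` migrating old configs in place. The same pattern can be sketched in isolation; `TLS`, `Network`, and `LegacyAwareConfig` below are illustrative stand-ins rather than the actual llama-stack classes, and the sketch assumes pydantic v2.7+ for `Field(deprecated=...)`:

```python
import warnings
from pathlib import Path

from pydantic import BaseModel, Field, model_validator


class TLS(BaseModel):
    verify: bool | Path = True


class Network(BaseModel):
    tls: TLS | None = None


class LegacyAwareConfig(BaseModel):
    network: Network | None = None
    tls_verify: bool | str | None = Field(default=None, deprecated=True)

    @model_validator(mode="after")
    def migrate_tls_verify(self) -> "LegacyAwareConfig":
        # Fold the deprecated flat field into the nested network.tls block.
        if self.tls_verify is not None:
            warnings.warn(
                "'tls_verify' is deprecated; use 'network.tls.verify'.",
                DeprecationWarning,
                stacklevel=2,
            )
            verify = Path(self.tls_verify) if isinstance(self.tls_verify, str) else self.tls_verify
            if self.network is None:
                self.network = Network(tls=TLS(verify=verify))
            elif self.network.tls is None:
                self.network.tls = TLS(verify=verify)
        return self


cfg = LegacyAwareConfig(tls_verify=False)  # emits a DeprecationWarning
print(cfg.network.tls.verify)  # False
```

Folding the legacy value in after validation keeps existing run configs working while steering new ones toward the nested layout shown in `sample_run_config`.
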
@@ -73,9 +73,6 @@ class VLLMInferenceAdapter(OpenAIMixin):
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
 
-    def get_extra_client_params(self):
-        return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
-
     async def check_model_availability(self, model: str) -> bool:
         """
         Skip the check when running without authentication.
@@ -23,6 +23,7 @@ from llama_stack_api import (
     OpenAICompletionRequestWithExtraBody,
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
+    validate_embeddings_input_is_text,
 )
 
 logger = get_logger(name=__name__, category="providers::remote::watsonx")
@@ -147,6 +148,9 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
         """
         Override parent method to add watsonx-specific parameters.
         """
+        # Validate that input contains only text, not token arrays
+        validate_embeddings_input_is_text(params)
+
         model_obj = await self.model_store.get_model(params.model)
 
         # Convert input to list if it's a string
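
The watsonx adapter (llama_stack/providers/remote/inference/watsonx/watsonx.py, per the file list) now calls `validate_embeddings_input_is_text` before building the embeddings request. The helper's implementation is not shown in this diff; a plausible sketch of such a guard, rejecting token-array inputs before they reach a text-only embeddings backend, might look like the following, with `EmbeddingsParams` as an illustrative stand-in for the request object:

```python
class EmbeddingsParams:
    """Illustrative stand-in for the embeddings request object used above."""

    def __init__(self, model: str, input: str | list) -> None:
        self.model = model
        self.input = input


def validate_text_only(params: EmbeddingsParams) -> None:
    # Accept a single string or a list of strings; reject token arrays (lists of ints).
    value = params.input
    items = value if isinstance(value, list) else [value]
    if not all(isinstance(item, str) for item in items):
        raise ValueError("This provider only supports text inputs for embeddings, not token arrays.")


validate_text_only(EmbeddingsParams(model="some-embedding-model", input=["hello", "world"]))  # ok
try:
    validate_text_only(EmbeddingsParams(model="some-embedding-model", input=[[1, 2, 3]]))
except ValueError as exc:
    print(exc)
```
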
@@ -0,0 +1,151 @@
+# NVIDIA Post-Training Provider for LlamaStack
+
+This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service.
+
+## Features
+
+- Supervised fine-tuning of Llama models
+- LoRA fine-tuning support
+- Job management and status tracking
+
+## Getting Started
+
+### Prerequisites
+
+- LlamaStack with NVIDIA configuration
+- Access to Hosted NVIDIA NeMo Customizer service
+- Dataset registered in the Hosted NVIDIA NeMo Customizer service
+- Base model downloaded and available in the Hosted NVIDIA NeMo Customizer service
+
+### Setup
+
+Build the NVIDIA environment:
+
+```bash
+uv pip install llama-stack-client
+uv run llama stack list-deps nvidia | xargs -L1 uv pip install
+```
+
+### Basic Usage using the LlamaStack Python Client
+
+### Create Customization Job
+
+#### Initialize the client
+
+```python
+import os
+
+os.environ["NVIDIA_API_KEY"] = "your-api-key"
+os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
+os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
+os.environ["NVIDIA_PROJECT_ID"] = "test-project"
+os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
+
+from llama_stack.core.library_client import LlamaStackAsLibraryClient
+
+client = LlamaStackAsLibraryClient("nvidia")
+client.initialize()
+```
+
+#### Configure fine-tuning parameters
+
+```python
+from llama_stack_client.types.post_training_supervised_fine_tune_params import (
+    TrainingConfig,
+    TrainingConfigDataConfig,
+    TrainingConfigOptimizerConfig,
+)
+from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
+```
+
+#### Set up LoRA configuration
+
+```python
+algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16)
+```
+
+#### Configure training data
+
+```python
+data_config = TrainingConfigDataConfig(
+    dataset_id="your-dataset-id",  # Use client.datasets.list() to see available datasets
+    batch_size=16,
+)
+```
+
+#### Configure optimizer
+
+```python
+optimizer_config = TrainingConfigOptimizerConfig(
+    lr=0.0001,
+)
+```
+
+#### Set up training configuration
+
+```python
+training_config = TrainingConfig(
+    n_epochs=2,
+    data_config=data_config,
+    optimizer_config=optimizer_config,
+)
+```
+
+#### Start fine-tuning job
+
+```python
+training_job = client.post_training.supervised_fine_tune(
+    job_uuid="unique-job-id",
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    checkpoint_dir="",
+    algorithm_config=algorithm_config,
+    training_config=training_config,
+    logger_config={},
+    hyperparam_search_config={},
+)
+```
+
+### List all jobs
+
+```python
+jobs = client.post_training.job.list()
+```
+
+### Check job status
+
+```python
+job_status = client.post_training.job.status(job_uuid="your-job-id")
+```
+
+### Cancel a job
+
+```python
+client.post_training.job.cancel(job_uuid="your-job-id")
+```
+
+### Inference with the fine-tuned model
+
+#### 1. Register the model
+
+```python
+from llama_stack_api.models import Model, ModelType
+
+client.models.register(
+    model_id="test-example-model@v1",
+    provider_id="nvidia",
+    provider_model_id="test-example-model@v1",
+    model_type=ModelType.llm,
+)
+```
+
+#### 2. Inference with the fine-tuned model
+
+```python
+response = client.completions.create(
+    prompt="Complete the sentence using one word: Roses are red, violets are ",
+    stream=False,
+    model="test-example-model@v1",
+    max_tokens=50,
+)
+print(response.choices[0].text)
+```
@@ -5,23 +5,15 @@
 # the root directory of this source tree.
 
 
-from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.utils.inference.model_registry import (
-    ProviderModelEntry,
-    build_hf_repo_model_entry,
-)
+from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
 
 _MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta/llama-3.1-8b-instruct",
-        CoreModelId.llama3_1_8b_instruct.value,
+        "Llama3.1-8B-Instruct",
     ),
     build_hf_repo_model_entry(
         "meta/llama-3.2-1b-instruct",
-        CoreModelId.llama3_2_1b_instruct.value,
+        "Llama3.2-1B-Instruct",
     ),
 ]
-
-
-def get_model_entries() -> list[ProviderModelEntry]:
-    return _MODEL_ENTRIES
@@ -14,13 +14,15 @@ from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostT
 from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack_api import (
-    AlgorithmConfig,
-    DPOAlignmentConfig,
+    CancelTrainingJobRequest,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-    TrainingConfig,
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )
 
 from .models import _MODEL_ENTRIES
@@ -156,7 +158,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
 
         return ListNvidiaPostTrainingJobs(data=jobs)
 
-    async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> NvidiaPostTrainingJobStatusResponse:
         """Get the status of a customization job.
 
         Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
@@ -178,8 +182,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         """
         response = await self._make_request(
             "GET",
-            f"/v1/customization/jobs/{job_uuid}/status",
-            params={"job_id": job_uuid},
+            f"/v1/customization/jobs/{request.job_uuid}/status",
+            params={"job_id": request.job_uuid},
         )
 
         api_status = response.pop("status").lower()
@@ -187,18 +191,20 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
 
         return NvidiaPostTrainingJobStatusResponse(
             status=JobStatus(mapped_status),
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             started_at=datetime.fromisoformat(response.pop("created_at")),
             updated_at=datetime.fromisoformat(response.pop("updated_at")),
             **response,
         )
 
-    async def cancel_training_job(self, job_uuid: str) -> None:
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
         await self._make_request(
-            method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
+            method="POST", path=f"/v1/customization/jobs/{request.job_uuid}/cancel", params={"job_id": request.job_uuid}
         )
 
-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse:
         raise NotImplementedError("Job artifacts are not implemented yet")
 
     async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
@@ -206,13 +212,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
 
     async def supervised_fine_tune(
         self,
-        job_uuid: str,
-        training_config: dict[str, Any],
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None,
-        algorithm_config: AlgorithmConfig | None = None,
+        request: SupervisedFineTuneRequest,
     ) -> NvidiaPostTrainingJob:
         """
         Fine-tunes a model on a dataset.
@@ -300,13 +300,16 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         User is informed about unsupported parameters via warnings.
         """
 
+        # Convert training_config to dict for internal processing
+        training_config = request.training_config.model_dump()
+
         # Check for unsupported method parameters
         unsupported_method_params = []
-        if checkpoint_dir:
-            unsupported_method_params.append(f"checkpoint_dir={checkpoint_dir}")
-        if hyperparam_search_config:
+        if request.checkpoint_dir:
+            unsupported_method_params.append(f"checkpoint_dir={request.checkpoint_dir}")
+        if request.hyperparam_search_config:
             unsupported_method_params.append("hyperparam_search_config")
-        if logger_config:
+        if request.logger_config:
             unsupported_method_params.append("logger_config")
 
         if unsupported_method_params:
@@ -344,7 +347,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
 
         # Prepare base job configuration
         job_config = {
-            "config": model,
+            "config": request.model,
             "dataset": {
                 "name": training_config["data_config"]["dataset_id"],
                 "namespace": self.config.dataset_namespace,
@@ -388,14 +391,14 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         job_config["hyperparameters"].pop("sft")
 
         # Handle LoRA-specific configuration
-        if algorithm_config:
-            if algorithm_config.type == "LoRA":
-                warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
+        if request.algorithm_config:
+            if request.algorithm_config.type == "LoRA":
+                warn_unsupported_params(request.algorithm_config, supported_params["lora_config"], "LoRA config")
                 job_config["hyperparameters"]["lora"] = {
-                    k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None
+                    k: v for k, v in {"alpha": request.algorithm_config.alpha}.items() if v is not None
                 }
             else:
-                raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
+                raise NotImplementedError(f"Unsupported algorithm config: {request.algorithm_config}")
 
         # Create the customization job
         response = await self._make_request(
@@ -416,12 +419,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
 
     async def preference_optimize(
         self,
-        job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         """Optimize a model based on preference data."""
         raise NotImplementedError("Preference optimization is not implemented yet")
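
Across llama_stack/providers/remote/post_training/nvidia/post_training.py, each endpoint now accepts a single request object instead of loose keyword arguments. Judging only from the fields these hunks dereference (`job_uuid`, `model`, `checkpoint_dir`, `algorithm_config`, `training_config`, `logger_config`, `hyperparam_search_config`), a call would look roughly like the sketch below; the real `SupervisedFineTuneRequest` is defined in `llama_stack_api` (its `training_config` is a pydantic model, hence the `model_dump()` call above) and may differ in detail:

```python
from dataclasses import dataclass, field
from typing import Any


# Illustrative stand-in mirroring the fields read off the request object above;
# the actual SupervisedFineTuneRequest lives in llama_stack_api.
@dataclass
class SupervisedFineTuneRequestSketch:
    job_uuid: str
    model: str
    training_config: dict[str, Any]
    checkpoint_dir: str | None = None
    algorithm_config: Any | None = None
    logger_config: dict[str, Any] = field(default_factory=dict)
    hyperparam_search_config: dict[str, Any] = field(default_factory=dict)


request = SupervisedFineTuneRequestSketch(
    job_uuid="unique-job-id",
    model="meta/llama-3.1-8b-instruct",
    training_config={"data_config": {"dataset_id": "your-dataset-id", "batch_size": 16}},
)
# job = await adapter.supervised_fine_tune(request)  # adapter: NvidiaPostTrainingAdapter
print(request.model)
```
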
@@ -5,12 +5,13 @@
 # the root directory of this source tree.
 
 import json
-from typing import Any
 
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-    OpenAIMessageParam,
+    GetShieldRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -24,7 +25,7 @@ from .config import BedrockSafetyConfig
 
 logger = get_logger(name=__name__, category="safety::bedrock")
 
-class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
+class BedrockSafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
     def __init__(self, config: BedrockSafetyConfig) -> None:
         self.config = config
         self.registered_shields = []
@@ -55,49 +56,31 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
     async def unregister_shield(self, identifier: str) -> None:
         pass
 
-    async def run_shield(
-        self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] = None
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
-
-        """
-        This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
-        ```content = [
-            {
-                "text": {
-                    "text": "Is the AB503 Product a better investment than the S&P 500?"
-                }
-            }
-        ]```
-        Incoming messages contain content, role . For now we will extract the content and
-        default the "qualifiers": ["query"]
-        """
+            raise ValueError(f"Shield {request.shield_id} not found")
 
         shield_params = shield.params
-        logger.debug(f"run_shield::{shield_params}::messages={messages}")
+        logger.debug(f"run_shield::{shield_params}::messages={request.messages}")
 
-        # - convert the messages into format Bedrock expects
         content_messages = []
-        for message in messages:
+        for message in request.messages:
             content_messages.append({"text": {"text": message.content}})
         logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
 
         response = self.bedrock_runtime_client.apply_guardrail(
             guardrailIdentifier=shield.provider_resource_id,
             guardrailVersion=shield_params["guardrailVersion"],
-            source="OUTPUT",  # or 'INPUT' depending on your use case
+            source="OUTPUT",
             content=content_messages,
         )
         if response["action"] == "GUARDRAIL_INTERVENED":
             user_message = ""
             metadata = {}
             for output in response["outputs"]:
-                # guardrails returns a list - however for this implementation we will leverage the last values
                 user_message = output["text"]
             for assessment in response["assessments"]:
-                # guardrails returns a list - however for this implementation we will leverage the last values
                 metadata = dict(assessment)
 
         return RunShieldResponse(
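
The Bedrock adapter (llama_stack/providers/remote/safety/bedrock/bedrock.py), like the NVIDIA one further below, now takes a `RunShieldRequest` and looks up shields via `GetShieldRequest`. A hedged sketch of the new call shape follows, using only the two fields these hunks dereference; the dataclasses are illustrative stand-ins for the real `llama_stack_api` types:

```python
from dataclasses import dataclass


# Illustrative stand-ins; the real RunShieldRequest and message types come from llama_stack_api.
@dataclass
class UserMessage:
    role: str
    content: str


@dataclass
class RunShieldRequestSketch:
    shield_id: str
    messages: list[UserMessage]


async def check(adapter, text: str) -> None:
    # adapter would be a configured BedrockSafetyAdapter or NVIDIASafetyAdapter instance.
    request = RunShieldRequestSketch(
        shield_id="my-guardrail-shield",
        messages=[UserMessage(role="user", content=text)],
    )
    response = await adapter.run_shield(request)
    if response.violation:
        print(f"blocked: {response.violation.user_message}")
    else:
        print("allowed")
```
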
@@ -0,0 +1,78 @@
+# NVIDIA Safety Provider for LlamaStack
+
+This provider enables safety checks and guardrails for LLM interactions using NVIDIA's NeMo Guardrails service.
+
+## Features
+
+- Run safety checks for messages
+
+## Getting Started
+
+### Prerequisites
+
+- LlamaStack with NVIDIA configuration
+- Access to NVIDIA NeMo Guardrails service
+- NIM for model to use for safety check is deployed
+
+### Setup
+
+Build the NVIDIA environment:
+
+```bash
+uv pip install llama-stack-client
+uv run llama stack list-deps nvidia | xargs -L1 uv pip install
+```
+
+### Basic Usage using the LlamaStack Python Client
+
+#### Initialize the client
+
+```python
+import os
+
+os.environ["NVIDIA_API_KEY"] = "your-api-key"
+os.environ["NVIDIA_GUARDRAILS_URL"] = "http://guardrails.test"
+
+from llama_stack.core.library_client import LlamaStackAsLibraryClient
+
+client = LlamaStackAsLibraryClient("nvidia")
+client.initialize()
+```
+
+#### Create a safety shield
+
+```python
+from llama_stack_api.safety import Shield
+from llama_stack_api.inference import Message
+
+# Create a safety shield
+shield = Shield(
+    shield_id="your-shield-id",
+    provider_resource_id="safety-model-id",  # The model to use for safety checks
+    description="Safety checks for content moderation",
+)
+
+# Register the shield
+await client.safety.register_shield(shield)
+```
+
+#### Run safety checks
+
+```python
+# Messages to check
+messages = [Message(role="user", content="Your message to check")]
+
+# Run safety check
+response = await client.safety.run_shield(
+    shield_id="your-shield-id",
+    messages=messages,
+)
+
+# Check for violations
+if response.violation:
+    print(f"Safety violation detected: {response.violation.user_message}")
+    print(f"Violation level: {response.violation.violation_level}")
+    print(f"Metadata: {response.violation.metadata}")
+else:
+    print("No safety violations detected")
+```
@@ -9,9 +9,11 @@ from typing import Any
 import requests
 
 from llama_stack.log import get_logger
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-    ModerationObject,
+    GetShieldRequest,
     OpenAIMessageParam,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -25,7 +27,7 @@ from .config import NVIDIASafetyConfig
 
 logger = get_logger(name=__name__, category="safety::nvidia")
 
-class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
+class NVIDIASafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
     def __init__(self, config: NVIDIASafetyConfig) -> None:
         """
         Initialize the NVIDIASafetyAdapter with a given safety configuration.
@@ -48,32 +50,14 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
     async def unregister_shield(self, identifier: str) -> None:
         pass
 
-    async def run_shield(
-        self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
-    ) -> RunShieldResponse:
-        """
-        Run a safety shield check against the provided messages.
-
-        Args:
-            shield_id (str): The unique identifier for the shield to be used.
-            messages (List[Message]): A list of Message objects representing the conversation history.
-            params (Optional[dict[str, Any]]): Additional parameters for the shield check.
-
-        Returns:
-            RunShieldResponse: The response containing safety violation details if any.
-
-        Raises:
-            ValueError: If the shield with the provided shield_id is not found.
-        """
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        """Run a safety shield check against the provided messages."""
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
+            raise ValueError(f"Shield {request.shield_id} not found")
 
         self.shield = NeMoGuardrails(self.config, shield.shield_id)
-        return await self.shield.run(messages)
-
-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
-        raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
+        return await self.shield.run(request.messages)
 
 
 class NeMoGuardrails: