llama-stack 0.4.3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (307)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +183 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  71. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  72. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  73. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
  74. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  75. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  76. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  77. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  78. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  79. llama_stack/providers/registry/agents.py +1 -0
  80. llama_stack/providers/registry/inference.py +1 -9
  81. llama_stack/providers/registry/vector_io.py +136 -16
  82. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  83. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  84. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  85. llama_stack/providers/remote/files/s3/README.md +266 -0
  86. llama_stack/providers/remote/files/s3/config.py +5 -3
  87. llama_stack/providers/remote/files/s3/files.py +2 -2
  88. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  89. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  90. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  91. llama_stack/providers/remote/inference/together/together.py +4 -0
  92. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  93. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  94. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  95. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  96. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  97. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  98. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  99. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  100. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  101. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  102. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  103. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  104. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  105. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  106. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  107. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  108. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  109. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  110. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  111. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  112. llama_stack/providers/utils/bedrock/client.py +3 -3
  113. llama_stack/providers/utils/bedrock/config.py +7 -7
  114. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  115. llama_stack/providers/utils/inference/http_client.py +239 -0
  116. llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
  117. llama_stack/providers/utils/inference/model_registry.py +148 -2
  118. llama_stack/providers/utils/inference/openai_compat.py +2 -1
  119. llama_stack/providers/utils/inference/openai_mixin.py +41 -2
  120. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  121. llama_stack/providers/utils/memory/vector_store.py +46 -19
  122. llama_stack/providers/utils/responses/responses_store.py +40 -6
  123. llama_stack/providers/utils/safety.py +114 -0
  124. llama_stack/providers/utils/tools/mcp.py +44 -3
  125. llama_stack/testing/api_recorder.py +9 -3
  126. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
  127. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +131 -275
  128. llama_stack-0.5.0rc1.dist-info/top_level.txt +1 -0
  129. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  130. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  131. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  132. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  133. llama_stack/models/llama/hadamard_utils.py +0 -88
  134. llama_stack/models/llama/llama3/args.py +0 -74
  135. llama_stack/models/llama/llama3/generation.py +0 -378
  136. llama_stack/models/llama/llama3/model.py +0 -304
  137. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  138. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  139. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  140. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  141. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  142. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  143. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  144. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  145. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  146. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  147. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  148. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  149. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  150. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  151. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  152. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  153. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  154. llama_stack/models/llama/llama4/args.py +0 -107
  155. llama_stack/models/llama/llama4/ffn.py +0 -58
  156. llama_stack/models/llama/llama4/moe.py +0 -214
  157. llama_stack/models/llama/llama4/preprocess.py +0 -435
  158. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  159. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  160. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  161. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  162. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  163. llama_stack/models/llama/quantize_impls.py +0 -316
  164. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  165. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  166. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  167. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  168. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  169. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  170. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  171. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  172. llama_stack_api/__init__.py +0 -945
  173. llama_stack_api/admin/__init__.py +0 -45
  174. llama_stack_api/admin/api.py +0 -72
  175. llama_stack_api/admin/fastapi_routes.py +0 -117
  176. llama_stack_api/admin/models.py +0 -113
  177. llama_stack_api/agents.py +0 -173
  178. llama_stack_api/batches/__init__.py +0 -40
  179. llama_stack_api/batches/api.py +0 -53
  180. llama_stack_api/batches/fastapi_routes.py +0 -113
  181. llama_stack_api/batches/models.py +0 -78
  182. llama_stack_api/benchmarks/__init__.py +0 -43
  183. llama_stack_api/benchmarks/api.py +0 -39
  184. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  185. llama_stack_api/benchmarks/models.py +0 -109
  186. llama_stack_api/common/__init__.py +0 -5
  187. llama_stack_api/common/content_types.py +0 -101
  188. llama_stack_api/common/errors.py +0 -95
  189. llama_stack_api/common/job_types.py +0 -38
  190. llama_stack_api/common/responses.py +0 -77
  191. llama_stack_api/common/training_types.py +0 -47
  192. llama_stack_api/common/type_system.py +0 -146
  193. llama_stack_api/connectors.py +0 -146
  194. llama_stack_api/conversations.py +0 -270
  195. llama_stack_api/datasetio.py +0 -55
  196. llama_stack_api/datasets/__init__.py +0 -61
  197. llama_stack_api/datasets/api.py +0 -35
  198. llama_stack_api/datasets/fastapi_routes.py +0 -104
  199. llama_stack_api/datasets/models.py +0 -152
  200. llama_stack_api/datatypes.py +0 -373
  201. llama_stack_api/eval.py +0 -137
  202. llama_stack_api/file_processors/__init__.py +0 -27
  203. llama_stack_api/file_processors/api.py +0 -64
  204. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  205. llama_stack_api/file_processors/models.py +0 -42
  206. llama_stack_api/files/__init__.py +0 -35
  207. llama_stack_api/files/api.py +0 -51
  208. llama_stack_api/files/fastapi_routes.py +0 -124
  209. llama_stack_api/files/models.py +0 -107
  210. llama_stack_api/inference.py +0 -1169
  211. llama_stack_api/inspect_api/__init__.py +0 -37
  212. llama_stack_api/inspect_api/api.py +0 -25
  213. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  214. llama_stack_api/inspect_api/models.py +0 -28
  215. llama_stack_api/internal/kvstore.py +0 -28
  216. llama_stack_api/internal/sqlstore.py +0 -81
  217. llama_stack_api/llama_stack_api/__init__.py +0 -945
  218. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  219. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  220. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  221. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  222. llama_stack_api/llama_stack_api/agents.py +0 -173
  223. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  224. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  225. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  226. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  227. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  228. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  229. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  230. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  231. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  232. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  233. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  234. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  235. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  236. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  237. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  238. llama_stack_api/llama_stack_api/connectors.py +0 -146
  239. llama_stack_api/llama_stack_api/conversations.py +0 -270
  240. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  241. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  242. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  243. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  244. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  245. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  246. llama_stack_api/llama_stack_api/eval.py +0 -137
  247. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  248. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  249. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  250. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  251. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  252. llama_stack_api/llama_stack_api/files/api.py +0 -51
  253. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  254. llama_stack_api/llama_stack_api/files/models.py +0 -107
  255. llama_stack_api/llama_stack_api/inference.py +0 -1169
  256. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  257. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  258. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  259. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  260. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  261. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  262. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  263. llama_stack_api/llama_stack_api/models.py +0 -171
  264. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  265. llama_stack_api/llama_stack_api/post_training.py +0 -370
  266. llama_stack_api/llama_stack_api/prompts.py +0 -203
  267. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  268. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  269. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  270. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  271. llama_stack_api/llama_stack_api/py.typed +0 -0
  272. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  273. llama_stack_api/llama_stack_api/resource.py +0 -37
  274. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  275. llama_stack_api/llama_stack_api/safety.py +0 -132
  276. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  277. llama_stack_api/llama_stack_api/scoring.py +0 -93
  278. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  279. llama_stack_api/llama_stack_api/shields.py +0 -93
  280. llama_stack_api/llama_stack_api/tools.py +0 -226
  281. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  282. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  283. llama_stack_api/llama_stack_api/version.py +0 -9
  284. llama_stack_api/models.py +0 -171
  285. llama_stack_api/openai_responses.py +0 -1468
  286. llama_stack_api/post_training.py +0 -370
  287. llama_stack_api/prompts.py +0 -203
  288. llama_stack_api/providers/__init__.py +0 -33
  289. llama_stack_api/providers/api.py +0 -16
  290. llama_stack_api/providers/fastapi_routes.py +0 -57
  291. llama_stack_api/providers/models.py +0 -24
  292. llama_stack_api/py.typed +0 -0
  293. llama_stack_api/rag_tool.py +0 -168
  294. llama_stack_api/resource.py +0 -37
  295. llama_stack_api/router_utils.py +0 -160
  296. llama_stack_api/safety.py +0 -132
  297. llama_stack_api/schema_utils.py +0 -208
  298. llama_stack_api/scoring.py +0 -93
  299. llama_stack_api/scoring_functions.py +0 -211
  300. llama_stack_api/shields.py +0 -93
  301. llama_stack_api/tools.py +0 -226
  302. llama_stack_api/vector_io.py +0 -941
  303. llama_stack_api/vector_stores.py +0 -53
  304. llama_stack_api/version.py +0 -9
  305. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
  306. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
  307. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
@@ -23,6 +23,7 @@ from llama_stack_api import (
     OpenAICompletionRequestWithExtraBody,
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
+    validate_embeddings_input_is_text,
 )

 logger = get_logger(name=__name__, category="providers::remote::watsonx")
@@ -147,6 +148,9 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
         """
         Override parent method to add watsonx-specific parameters.
         """
+        # Validate that input contains only text, not token arrays
+        validate_embeddings_input_is_text(params)
+
         model_obj = await self.model_store.get_model(params.model)

         # Convert input to list if it's a string
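The hunk above wires a new `validate_embeddings_input_is_text` helper (imported from `llama_stack_api`) into the watsonx embeddings path before the request is forwarded. As a rough sketch of the idea only — the real helper lives in `llama_stack_api` and its exact signature and error type are not shown in this diff — such a guard would reject OpenAI-style token-array inputs:

```python
# Hypothetical sketch, not the shipped implementation: reject embeddings
# requests whose `input` is a token array (list of ints) rather than text.
def validate_embeddings_input_is_text(params) -> None:
    inputs = params.input if isinstance(params.input, list) else [params.input]
    for item in inputs:
        if not isinstance(item, str):
            raise ValueError(
                "This provider only supports text input for embeddings; "
                "token arrays are not supported."
            )
```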
@@ -0,0 +1,151 @@
+# NVIDIA Post-Training Provider for LlamaStack
+
+This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service.
+
+## Features
+
+- Supervised fine-tuning of Llama models
+- LoRA fine-tuning support
+- Job management and status tracking
+
+## Getting Started
+
+### Prerequisites
+
+- LlamaStack with NVIDIA configuration
+- Access to Hosted NVIDIA NeMo Customizer service
+- Dataset registered in the Hosted NVIDIA NeMo Customizer service
+- Base model downloaded and available in the Hosted NVIDIA NeMo Customizer service
+
+### Setup
+
+Build the NVIDIA environment:
+
+```bash
+uv pip install llama-stack-client
+uv run llama stack list-deps nvidia | xargs -L1 uv pip install
+```
+
+### Basic Usage using the LlamaStack Python Client
+
+### Create Customization Job
+
+#### Initialize the client
+
+```python
+import os
+
+os.environ["NVIDIA_API_KEY"] = "your-api-key"
+os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
+os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
+os.environ["NVIDIA_PROJECT_ID"] = "test-project"
+os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
+
+from llama_stack.core.library_client import LlamaStackAsLibraryClient
+
+client = LlamaStackAsLibraryClient("nvidia")
+client.initialize()
+```
+
+#### Configure fine-tuning parameters
+
+```python
+from llama_stack_client.types.post_training_supervised_fine_tune_params import (
+    TrainingConfig,
+    TrainingConfigDataConfig,
+    TrainingConfigOptimizerConfig,
+)
+from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
+```
+
+#### Set up LoRA configuration
+
+```python
+algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16)
+```
+
+#### Configure training data
+
+```python
+data_config = TrainingConfigDataConfig(
+    dataset_id="your-dataset-id",  # Use client.datasets.list() to see available datasets
+    batch_size=16,
+)
+```
+
+#### Configure optimizer
+
+```python
+optimizer_config = TrainingConfigOptimizerConfig(
+    lr=0.0001,
+)
+```
+
+#### Set up training configuration
+
+```python
+training_config = TrainingConfig(
+    n_epochs=2,
+    data_config=data_config,
+    optimizer_config=optimizer_config,
+)
+```
+
+#### Start fine-tuning job
+
+```python
+training_job = client.post_training.supervised_fine_tune(
+    job_uuid="unique-job-id",
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    checkpoint_dir="",
+    algorithm_config=algorithm_config,
+    training_config=training_config,
+    logger_config={},
+    hyperparam_search_config={},
+)
+```
+
+### List all jobs
+
+```python
+jobs = client.post_training.job.list()
+```
+
+### Check job status
+
+```python
+job_status = client.post_training.job.status(job_uuid="your-job-id")
+```
+
+### Cancel a job
+
+```python
+client.post_training.job.cancel(job_uuid="your-job-id")
+```
+
+### Inference with the fine-tuned model
+
+#### 1. Register the model
+
+```python
+from llama_stack_api.models import Model, ModelType
+
+client.models.register(
+    model_id="test-example-model@v1",
+    provider_id="nvidia",
+    provider_model_id="test-example-model@v1",
+    model_type=ModelType.llm,
+)
+```
+
+#### 2. Inference with the fine-tuned model
+
+```python
+response = client.completions.create(
+    prompt="Complete the sentence using one word: Roses are red, violets are ",
+    stream=False,
+    model="test-example-model@v1",
+    max_tokens=50,
+)
+print(response.choices[0].text)
+```
@@ -14,13 +14,15 @@ from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostT
 from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack_api import (
-    AlgorithmConfig,
-    DPOAlignmentConfig,
+    CancelTrainingJobRequest,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-    TrainingConfig,
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )

 from .models import _MODEL_ENTRIES
@@ -156,7 +158,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):

         return ListNvidiaPostTrainingJobs(data=jobs)

-    async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> NvidiaPostTrainingJobStatusResponse:
         """Get the status of a customization job.
         Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.

@@ -178,8 +182,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         """
         response = await self._make_request(
             "GET",
-            f"/v1/customization/jobs/{job_uuid}/status",
-            params={"job_id": job_uuid},
+            f"/v1/customization/jobs/{request.job_uuid}/status",
+            params={"job_id": request.job_uuid},
         )

         api_status = response.pop("status").lower()
@@ -187,18 +191,20 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):

         return NvidiaPostTrainingJobStatusResponse(
             status=JobStatus(mapped_status),
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             started_at=datetime.fromisoformat(response.pop("created_at")),
             updated_at=datetime.fromisoformat(response.pop("updated_at")),
             **response,
         )

-    async def cancel_training_job(self, job_uuid: str) -> None:
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
         await self._make_request(
-            method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
+            method="POST", path=f"/v1/customization/jobs/{request.job_uuid}/cancel", params={"job_id": request.job_uuid}
         )

-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse:
         raise NotImplementedError("Job artifacts are not implemented yet")

     async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
@@ -206,13 +212,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):

     async def supervised_fine_tune(
         self,
-        job_uuid: str,
-        training_config: dict[str, Any],
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None,
-        algorithm_config: AlgorithmConfig | None = None,
+        request: SupervisedFineTuneRequest,
     ) -> NvidiaPostTrainingJob:
         """
         Fine-tunes a model on a dataset.
@@ -300,13 +300,16 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         User is informed about unsupported parameters via warnings.
         """

+        # Convert training_config to dict for internal processing
+        training_config = request.training_config.model_dump()
+
         # Check for unsupported method parameters
         unsupported_method_params = []
-        if checkpoint_dir:
-            unsupported_method_params.append(f"checkpoint_dir={checkpoint_dir}")
-        if hyperparam_search_config:
+        if request.checkpoint_dir:
+            unsupported_method_params.append(f"checkpoint_dir={request.checkpoint_dir}")
+        if request.hyperparam_search_config:
             unsupported_method_params.append("hyperparam_search_config")
-        if logger_config:
+        if request.logger_config:
             unsupported_method_params.append("logger_config")

         if unsupported_method_params:
@@ -344,7 +347,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):

         # Prepare base job configuration
         job_config = {
-            "config": model,
+            "config": request.model,
             "dataset": {
                 "name": training_config["data_config"]["dataset_id"],
                 "namespace": self.config.dataset_namespace,
@@ -388,14 +391,14 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
             job_config["hyperparameters"].pop("sft")

         # Handle LoRA-specific configuration
-        if algorithm_config:
-            if algorithm_config.type == "LoRA":
-                warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
+        if request.algorithm_config:
+            if request.algorithm_config.type == "LoRA":
+                warn_unsupported_params(request.algorithm_config, supported_params["lora_config"], "LoRA config")
                 job_config["hyperparameters"]["lora"] = {
-                    k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None
+                    k: v for k, v in {"alpha": request.algorithm_config.alpha}.items() if v is not None
                 }
             else:
-                raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
+                raise NotImplementedError(f"Unsupported algorithm config: {request.algorithm_config}")

         # Create the customization job
         response = await self._make_request(
@@ -416,12 +419,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):

     async def preference_optimize(
         self,
-        job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         """Optimize a model based on preference data."""
         raise NotImplementedError("Preference optimization is not implemented yet")
@@ -5,12 +5,13 @@
 # the root directory of this source tree.

 import json
-from typing import Any

 from llama_stack.log import get_logger
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-    OpenAIMessageParam,
+    GetShieldRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -24,7 +25,7 @@ from .config import BedrockSafetyConfig
 logger = get_logger(name=__name__, category="safety::bedrock")


-class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
+class BedrockSafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
     def __init__(self, config: BedrockSafetyConfig) -> None:
         self.config = config
         self.registered_shields = []
@@ -55,49 +56,31 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
     async def unregister_shield(self, identifier: str) -> None:
         pass

-    async def run_shield(
-        self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] = None
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
-
-        """
-        This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
-        ```content = [
-            {
-                "text": {
-                    "text": "Is the AB503 Product a better investment than the S&P 500?"
-                }
-            }
-        ]```
-        Incoming messages contain content, role . For now we will extract the content and
-        default the "qualifiers": ["query"]
-        """
+            raise ValueError(f"Shield {request.shield_id} not found")

         shield_params = shield.params
-        logger.debug(f"run_shield::{shield_params}::messages={messages}")
+        logger.debug(f"run_shield::{shield_params}::messages={request.messages}")

-        # - convert the messages into format Bedrock expects
         content_messages = []
-        for message in messages:
+        for message in request.messages:
             content_messages.append({"text": {"text": message.content}})
         logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")

         response = self.bedrock_runtime_client.apply_guardrail(
             guardrailIdentifier=shield.provider_resource_id,
             guardrailVersion=shield_params["guardrailVersion"],
-            source="OUTPUT",  # or 'INPUT' depending on your use case
+            source="OUTPUT",
             content=content_messages,
         )
         if response["action"] == "GUARDRAIL_INTERVENED":
             user_message = ""
             metadata = {}
             for output in response["outputs"]:
-                # guardrails returns a list - however for this implementation we will leverage the last values
                 user_message = output["text"]
             for assessment in response["assessments"]:
-                # guardrails returns a list - however for this implementation we will leverage the last values
                 metadata = dict(assessment)

         return RunShieldResponse(
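This Bedrock hunk, and the NVIDIA and SambaNova hunks further down, share the same two changes: `run_shield` now takes a single `RunShieldRequest`, shields are looked up via `GetShieldRequest(identifier=...)`, and each adapter gains `ShieldToModerationMixin` from the new `llama_stack/providers/utils/safety.py`, which presumably supplies a shared `run_moderation` built on `run_shield` (the diff only shows the mixin being added). A hedged call-site sketch; only `shield_id` and `messages` are confirmed as request fields by the diff:

```python
# Illustrative call-site sketch for the request-object safety API.
from llama_stack_api import RunShieldRequest


async def check_messages(safety_adapter, shield_id: str, messages):
    # `messages` keeps the same OpenAI-style message params the adapters
    # already iterate over (each message exposes `.content`).
    response = await safety_adapter.run_shield(
        RunShieldRequest(shield_id=shield_id, messages=messages)
    )
    if response.violation:
        print(f"Blocked: {response.violation.user_message}")
    return response
```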
@@ -0,0 +1,78 @@
+# NVIDIA Safety Provider for LlamaStack
+
+This provider enables safety checks and guardrails for LLM interactions using NVIDIA's NeMo Guardrails service.
+
+## Features
+
+- Run safety checks for messages
+
+## Getting Started
+
+### Prerequisites
+
+- LlamaStack with NVIDIA configuration
+- Access to NVIDIA NeMo Guardrails service
+- NIM for model to use for safety check is deployed
+
+### Setup
+
+Build the NVIDIA environment:
+
+```bash
+uv pip install llama-stack-client
+uv run llama stack list-deps nvidia | xargs -L1 uv pip install
+```
+
+### Basic Usage using the LlamaStack Python Client
+
+#### Initialize the client
+
+```python
+import os
+
+os.environ["NVIDIA_API_KEY"] = "your-api-key"
+os.environ["NVIDIA_GUARDRAILS_URL"] = "http://guardrails.test"
+
+from llama_stack.core.library_client import LlamaStackAsLibraryClient
+
+client = LlamaStackAsLibraryClient("nvidia")
+client.initialize()
+```
+
+#### Create a safety shield
+
+```python
+from llama_stack_api.safety import Shield
+from llama_stack_api.inference import Message
+
+# Create a safety shield
+shield = Shield(
+    shield_id="your-shield-id",
+    provider_resource_id="safety-model-id",  # The model to use for safety checks
+    description="Safety checks for content moderation",
+)
+
+# Register the shield
+await client.safety.register_shield(shield)
+```
+
+#### Run safety checks
+
+```python
+# Messages to check
+messages = [Message(role="user", content="Your message to check")]
+
+# Run safety check
+response = await client.safety.run_shield(
+    shield_id="your-shield-id",
+    messages=messages,
+)
+
+# Check for violations
+if response.violation:
+    print(f"Safety violation detected: {response.violation.user_message}")
+    print(f"Violation level: {response.violation.violation_level}")
+    print(f"Metadata: {response.violation.metadata}")
+else:
+    print("No safety violations detected")
+```
@@ -9,9 +9,11 @@ from typing import Any
 import requests

 from llama_stack.log import get_logger
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-    ModerationObject,
+    GetShieldRequest,
     OpenAIMessageParam,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -25,7 +27,7 @@ from .config import NVIDIASafetyConfig
 logger = get_logger(name=__name__, category="safety::nvidia")


-class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
+class NVIDIASafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
     def __init__(self, config: NVIDIASafetyConfig) -> None:
         """
         Initialize the NVIDIASafetyAdapter with a given safety configuration.
@@ -48,32 +50,14 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
     async def unregister_shield(self, identifier: str) -> None:
         pass

-    async def run_shield(
-        self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
-    ) -> RunShieldResponse:
-        """
-        Run a safety shield check against the provided messages.
-
-        Args:
-            shield_id (str): The unique identifier for the shield to be used.
-            messages (List[Message]): A list of Message objects representing the conversation history.
-            params (Optional[dict[str, Any]]): Additional parameters for the shield check.
-
-        Returns:
-            RunShieldResponse: The response containing safety violation details if any.
-
-        Raises:
-            ValueError: If the shield with the provided shield_id is not found.
-        """
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        """Run a safety shield check against the provided messages."""
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
+            raise ValueError(f"Shield {request.shield_id} not found")

         self.shield = NeMoGuardrails(self.config, shield.shield_id)
-        return await self.shield.run(messages)
-
-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
-        raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
+        return await self.shield.run(request.messages)


 class NeMoGuardrails:
@@ -4,15 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any
-
 import litellm
 import requests

 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-    OpenAIMessageParam,
+    GetShieldRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -28,7 +28,7 @@ logger = get_logger(name=__name__, category="safety::sambanova")
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"


-class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
+class SambaNovaSafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
     def __init__(self, config: SambaNovaSafetyConfig) -> None:
         self.config = config
         self.environment_available_models = []
@@ -69,17 +69,19 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
     async def unregister_shield(self, identifier: str) -> None:
         pass

-    async def run_shield(
-        self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
+            raise ValueError(f"Shield {request.shield_id} not found")

         shield_params = shield.params
-        logger.debug(f"run_shield::{shield_params}::messages={messages}")
+        logger.debug(f"run_shield::{shield_params}::messages={request.messages}")

-        response = litellm.completion(model=shield.provider_resource_id, messages=messages, api_key=self._get_api_key())
+        response = litellm.completion(
+            model=shield.provider_resource_id,
+            messages=request.messages,
+            api_key=self._get_api_key(),
+        )
         shield_message = response.choices[0].message.content

         if "unsafe" in shield_message.lower():
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack_api import Api, ProviderSpec
+
+from .config import ElasticsearchVectorIOConfig
+
+
+async def get_adapter_impl(config: ElasticsearchVectorIOConfig, deps: dict[Api, ProviderSpec]):
+    from .elasticsearch import ElasticsearchVectorIOAdapter
+
+    impl = ElasticsearchVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    await impl.initialize()
+    return impl
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from llama_stack.core.storage.datatypes import KVStoreReference
+from llama_stack_api import json_schema_type
+
+
+@json_schema_type
+class ElasticsearchVectorIOConfig(BaseModel):
+    elasticsearch_api_key: str | None = Field(description="The API key for the Elasticsearch instance", default=None)
+    elasticsearch_url: str | None = Field(description="The URL of the Elasticsearch instance", default="localhost:9200")
+    persistence: KVStoreReference | None = Field(
+        description="Config for KV store backend (SQLite only for now)", default=None
+    )
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
+        return {
+            "elasticsearch_url": "${env.ELASTICSEARCH_URL:=localhost:9200}",
+            "elasticsearch_api_key": "${env.ELASTICSEARCH_API_KEY:=}",
+            "persistence": KVStoreReference(
+                backend="kv_default",
+                namespace="vector_io::elasticsearch",
+            ).model_dump(exclude_none=True),
+        }
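The new Elasticsearch vector_io provider resolves its URL and API key from the `ELASTICSEARCH_URL` and `ELASTICSEARCH_API_KEY` environment variables in the sample run config, defaulting to `localhost:9200`. A small sketch of constructing the config directly, mirroring those defaults; the import path follows the file listing above and nothing beyond the fields shown in the diff is assumed:

```python
# Minimal sketch: build the Elasticsearch vector_io config in code,
# using the same defaults as sample_run_config above.
from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.providers.remote.vector_io.elasticsearch.config import (
    ElasticsearchVectorIOConfig,
)

config = ElasticsearchVectorIOConfig(
    elasticsearch_url="localhost:9200",
    elasticsearch_api_key=None,
    persistence=KVStoreReference(backend="kv_default", namespace="vector_io::elasticsearch"),
)
print(config.model_dump(exclude_none=True))
```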