llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (311) hide show
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/doc_template.md +209 -0
  30. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  31. llama_stack/distributions/nvidia/config.yaml +4 -1
  32. llama_stack/distributions/nvidia/doc_template.md +170 -0
  33. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  34. llama_stack/distributions/oci/config.yaml +4 -1
  35. llama_stack/distributions/oci/doc_template.md +140 -0
  36. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  37. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  38. llama_stack/distributions/starter/build.yaml +62 -0
  39. llama_stack/distributions/starter/config.yaml +22 -3
  40. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  41. llama_stack/distributions/starter/starter.py +13 -1
  42. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  43. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  44. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  45. llama_stack/distributions/template.py +10 -2
  46. llama_stack/distributions/watsonx/config.yaml +4 -1
  47. llama_stack/log.py +1 -0
  48. llama_stack/models/llama/resources/dog.jpg +0 -0
  49. llama_stack/models/llama/resources/pasta.jpeg +0 -0
  50. llama_stack/models/llama/resources/small_dog.jpg +0 -0
  51. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  52. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  53. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
  54. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  55. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  56. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  57. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  58. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  59. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  60. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
  61. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
  62. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
  63. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
  64. llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
  65. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
  66. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
  67. llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  68. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  69. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  70. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  71. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  72. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  73. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  74. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  75. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  76. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  77. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  78. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  79. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  80. llama_stack/providers/registry/agents.py +1 -0
  81. llama_stack/providers/registry/inference.py +1 -9
  82. llama_stack/providers/registry/vector_io.py +136 -16
  83. llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
  84. llama_stack/providers/remote/eval/nvidia/README.md +134 -0
  85. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  86. llama_stack/providers/remote/files/s3/README.md +266 -0
  87. llama_stack/providers/remote/files/s3/config.py +5 -3
  88. llama_stack/providers/remote/files/s3/files.py +2 -2
  89. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  90. llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
  91. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  92. llama_stack/providers/remote/inference/together/together.py +4 -0
  93. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  94. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  95. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  96. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  97. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  98. llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
  99. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  100. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  101. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  102. llama_stack/providers/remote/safety/nvidia/README.md +78 -0
  103. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  104. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  105. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  106. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  107. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  108. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  109. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  110. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  111. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  112. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  113. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  114. llama_stack/providers/utils/bedrock/client.py +3 -3
  115. llama_stack/providers/utils/bedrock/config.py +7 -7
  116. llama_stack/providers/utils/inference/__init__.py +0 -25
  117. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  118. llama_stack/providers/utils/inference/http_client.py +239 -0
  119. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  120. llama_stack/providers/utils/inference/model_registry.py +148 -2
  121. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  122. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  123. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  124. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  125. llama_stack/providers/utils/memory/vector_store.py +46 -19
  126. llama_stack/providers/utils/responses/responses_store.py +40 -6
  127. llama_stack/providers/utils/safety.py +114 -0
  128. llama_stack/providers/utils/tools/mcp.py +44 -3
  129. llama_stack/testing/api_recorder.py +9 -3
  130. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  131. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
  132. llama_stack-0.5.0.dist-info/top_level.txt +1 -0
  133. llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
  134. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  135. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  136. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  137. llama_stack/models/llama/hadamard_utils.py +0 -88
  138. llama_stack/models/llama/llama3/args.py +0 -74
  139. llama_stack/models/llama/llama3/generation.py +0 -378
  140. llama_stack/models/llama/llama3/model.py +0 -304
  141. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  142. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  143. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  144. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  145. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  146. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  147. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  148. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  149. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  150. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  151. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  152. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  153. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  154. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  155. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  156. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  157. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  158. llama_stack/models/llama/llama4/args.py +0 -107
  159. llama_stack/models/llama/llama4/ffn.py +0 -58
  160. llama_stack/models/llama/llama4/moe.py +0 -214
  161. llama_stack/models/llama/llama4/preprocess.py +0 -435
  162. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  163. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  164. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  165. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  166. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  167. llama_stack/models/llama/quantize_impls.py +0 -316
  168. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  169. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  170. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  171. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  172. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  173. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  174. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  175. llama_stack-0.4.3.dist-info/top_level.txt +0 -2
  176. llama_stack_api/__init__.py +0 -945
  177. llama_stack_api/admin/__init__.py +0 -45
  178. llama_stack_api/admin/api.py +0 -72
  179. llama_stack_api/admin/fastapi_routes.py +0 -117
  180. llama_stack_api/admin/models.py +0 -113
  181. llama_stack_api/agents.py +0 -173
  182. llama_stack_api/batches/__init__.py +0 -40
  183. llama_stack_api/batches/api.py +0 -53
  184. llama_stack_api/batches/fastapi_routes.py +0 -113
  185. llama_stack_api/batches/models.py +0 -78
  186. llama_stack_api/benchmarks/__init__.py +0 -43
  187. llama_stack_api/benchmarks/api.py +0 -39
  188. llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  189. llama_stack_api/benchmarks/models.py +0 -109
  190. llama_stack_api/common/__init__.py +0 -5
  191. llama_stack_api/common/content_types.py +0 -101
  192. llama_stack_api/common/errors.py +0 -95
  193. llama_stack_api/common/job_types.py +0 -38
  194. llama_stack_api/common/responses.py +0 -77
  195. llama_stack_api/common/training_types.py +0 -47
  196. llama_stack_api/common/type_system.py +0 -146
  197. llama_stack_api/connectors.py +0 -146
  198. llama_stack_api/conversations.py +0 -270
  199. llama_stack_api/datasetio.py +0 -55
  200. llama_stack_api/datasets/__init__.py +0 -61
  201. llama_stack_api/datasets/api.py +0 -35
  202. llama_stack_api/datasets/fastapi_routes.py +0 -104
  203. llama_stack_api/datasets/models.py +0 -152
  204. llama_stack_api/datatypes.py +0 -373
  205. llama_stack_api/eval.py +0 -137
  206. llama_stack_api/file_processors/__init__.py +0 -27
  207. llama_stack_api/file_processors/api.py +0 -64
  208. llama_stack_api/file_processors/fastapi_routes.py +0 -78
  209. llama_stack_api/file_processors/models.py +0 -42
  210. llama_stack_api/files/__init__.py +0 -35
  211. llama_stack_api/files/api.py +0 -51
  212. llama_stack_api/files/fastapi_routes.py +0 -124
  213. llama_stack_api/files/models.py +0 -107
  214. llama_stack_api/inference.py +0 -1169
  215. llama_stack_api/inspect_api/__init__.py +0 -37
  216. llama_stack_api/inspect_api/api.py +0 -25
  217. llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  218. llama_stack_api/inspect_api/models.py +0 -28
  219. llama_stack_api/internal/kvstore.py +0 -28
  220. llama_stack_api/internal/sqlstore.py +0 -81
  221. llama_stack_api/llama_stack_api/__init__.py +0 -945
  222. llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
  223. llama_stack_api/llama_stack_api/admin/api.py +0 -72
  224. llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
  225. llama_stack_api/llama_stack_api/admin/models.py +0 -113
  226. llama_stack_api/llama_stack_api/agents.py +0 -173
  227. llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
  228. llama_stack_api/llama_stack_api/batches/api.py +0 -53
  229. llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
  230. llama_stack_api/llama_stack_api/batches/models.py +0 -78
  231. llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
  232. llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
  233. llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
  234. llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
  235. llama_stack_api/llama_stack_api/common/__init__.py +0 -5
  236. llama_stack_api/llama_stack_api/common/content_types.py +0 -101
  237. llama_stack_api/llama_stack_api/common/errors.py +0 -95
  238. llama_stack_api/llama_stack_api/common/job_types.py +0 -38
  239. llama_stack_api/llama_stack_api/common/responses.py +0 -77
  240. llama_stack_api/llama_stack_api/common/training_types.py +0 -47
  241. llama_stack_api/llama_stack_api/common/type_system.py +0 -146
  242. llama_stack_api/llama_stack_api/connectors.py +0 -146
  243. llama_stack_api/llama_stack_api/conversations.py +0 -270
  244. llama_stack_api/llama_stack_api/datasetio.py +0 -55
  245. llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
  246. llama_stack_api/llama_stack_api/datasets/api.py +0 -35
  247. llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
  248. llama_stack_api/llama_stack_api/datasets/models.py +0 -152
  249. llama_stack_api/llama_stack_api/datatypes.py +0 -373
  250. llama_stack_api/llama_stack_api/eval.py +0 -137
  251. llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
  252. llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
  253. llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
  254. llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
  255. llama_stack_api/llama_stack_api/files/__init__.py +0 -35
  256. llama_stack_api/llama_stack_api/files/api.py +0 -51
  257. llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
  258. llama_stack_api/llama_stack_api/files/models.py +0 -107
  259. llama_stack_api/llama_stack_api/inference.py +0 -1169
  260. llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
  261. llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
  262. llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
  263. llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
  264. llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
  265. llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
  266. llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
  267. llama_stack_api/llama_stack_api/models.py +0 -171
  268. llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
  269. llama_stack_api/llama_stack_api/post_training.py +0 -370
  270. llama_stack_api/llama_stack_api/prompts.py +0 -203
  271. llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
  272. llama_stack_api/llama_stack_api/providers/api.py +0 -16
  273. llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
  274. llama_stack_api/llama_stack_api/providers/models.py +0 -24
  275. llama_stack_api/llama_stack_api/py.typed +0 -0
  276. llama_stack_api/llama_stack_api/rag_tool.py +0 -168
  277. llama_stack_api/llama_stack_api/resource.py +0 -37
  278. llama_stack_api/llama_stack_api/router_utils.py +0 -160
  279. llama_stack_api/llama_stack_api/safety.py +0 -132
  280. llama_stack_api/llama_stack_api/schema_utils.py +0 -208
  281. llama_stack_api/llama_stack_api/scoring.py +0 -93
  282. llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
  283. llama_stack_api/llama_stack_api/shields.py +0 -93
  284. llama_stack_api/llama_stack_api/tools.py +0 -226
  285. llama_stack_api/llama_stack_api/vector_io.py +0 -941
  286. llama_stack_api/llama_stack_api/vector_stores.py +0 -53
  287. llama_stack_api/llama_stack_api/version.py +0 -9
  288. llama_stack_api/models.py +0 -171
  289. llama_stack_api/openai_responses.py +0 -1468
  290. llama_stack_api/post_training.py +0 -370
  291. llama_stack_api/prompts.py +0 -203
  292. llama_stack_api/providers/__init__.py +0 -33
  293. llama_stack_api/providers/api.py +0 -16
  294. llama_stack_api/providers/fastapi_routes.py +0 -57
  295. llama_stack_api/providers/models.py +0 -24
  296. llama_stack_api/py.typed +0 -0
  297. llama_stack_api/rag_tool.py +0 -168
  298. llama_stack_api/resource.py +0 -37
  299. llama_stack_api/router_utils.py +0 -160
  300. llama_stack_api/safety.py +0 -132
  301. llama_stack_api/schema_utils.py +0 -208
  302. llama_stack_api/scoring.py +0 -93
  303. llama_stack_api/scoring_functions.py +0 -211
  304. llama_stack_api/shields.py +0 -93
  305. llama_stack_api/tools.py +0 -226
  306. llama_stack_api/vector_io.py +0 -941
  307. llama_stack_api/vector_stores.py +0 -53
  308. llama_stack_api/version.py +0 -9
  309. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  310. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  311. {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,595 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ import heapq
8
+ import json
9
+ from array import array
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+ import oracledb
14
+ from numpy.typing import NDArray
15
+
16
+ from llama_stack.core.storage.kvstore import kvstore_impl
17
+ from llama_stack.log import get_logger
18
+ from llama_stack.providers.remote.vector_io.oci.config import OCI26aiVectorIOConfig
19
+ from llama_stack.providers.utils.memory.openai_vector_store_mixin import VERSION as OPENAIMIXINVERSION
20
+ from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
21
+ from llama_stack.providers.utils.memory.vector_store import (
22
+ ChunkForDeletion,
23
+ EmbeddingIndex,
24
+ VectorStoreWithIndex,
25
+ )
26
+ from llama_stack.providers.utils.vector_io.vector_utils import (
27
+ WeightedInMemoryAggregator,
28
+ sanitize_collection_name,
29
+ )
30
+ from llama_stack_api import (
31
+ EmbeddedChunk,
32
+ Files,
33
+ Inference,
34
+ InterleavedContent,
35
+ QueryChunksResponse,
36
+ VectorIO,
37
+ VectorStore,
38
+ VectorStoreNotFoundError,
39
+ VectorStoresProtocolPrivate,
40
+ )
41
+ from llama_stack_api.internal.kvstore import KVStore
42
+
43
# Scoped logger for this provider module.
logger = get_logger(name=__name__, category="vector_io::oci26ai")

# Version tag for the KV namespaces owned by this provider.
VERSION = "v1"
# KV prefixes under which vector-store and vector-index metadata are persisted.
VECTOR_DBS_PREFIX = f"vector_stores:oci26ai:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:oci26ai:{VERSION}::"
# OpenAI-compatible store records are versioned by the mixin's own VERSION.
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:oci26ai:{OPENAIMIXINVERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:oci26ai:{OPENAIMIXINVERSION}::"
# NOTE(review): file *contents* use the local VERSION rather than
# OPENAIMIXINVERSION — looks intentional, but verify against the mixin.
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:oci26ai:{VERSION}::"
51
+
52
+
53
def normalize_embedding(embedding: NDArray) -> NDArray:
    """Normalize an embedding vector to unit length (L2 norm).

    Unit-length vectors are required for cosine similarity to behave
    correctly.

    Args:
        embedding: The raw embedding (any numeric array-like).

    Returns:
        A float64 numpy array whose L2 norm equals 1.

    Raises:
        ValueError: If ``embedding`` is None or has zero norm.
    """
    if embedding is None:
        raise ValueError("Embedding cannot be None")

    emb = np.asarray(embedding, dtype=np.float64)

    norm = np.linalg.norm(emb)
    if norm == 0.0:
        raise ValueError("Cannot normalize zero-length vector")

    return emb / norm
69
+
70
+
71
+ class OCI26aiIndex(EmbeddingIndex):
72
+ def __init__(
73
+ self,
74
+ connection,
75
+ vector_store: VectorStore,
76
+ consistency_level="Strong",
77
+ kvstore: KVStore | None = None,
78
+ vector_datatype: str = "FLOAT32",
79
+ ):
80
+ self.connection = connection
81
+ self.vector_store = vector_store
82
+ self.table_name = sanitize_collection_name(vector_store.vector_store_id)
83
+ self.dimensions = vector_store.embedding_dimension
84
+ self.consistency_level = consistency_level
85
+ self.kvstore = kvstore
86
+ self.vector_datatype = vector_datatype
87
+
88
+ async def initialize(self) -> None:
89
+ logger.info(f"Attempting to create table: {self.table_name}")
90
+ cursor = self.connection.cursor()
91
+ try:
92
+ # Create table
93
+ create_table_sql = f"""
94
+ CREATE TABLE IF NOT EXISTS {self.table_name} (
95
+ chunk_id VARCHAR2(100) PRIMARY KEY,
96
+ content CLOB,
97
+ vector VECTOR({self.dimensions}, {self.vector_datatype}),
98
+ metadata JSON,
99
+ chunk_metadata JSON
100
+ );
101
+ """
102
+ logger.debug(f"Executing SQL: {create_table_sql}")
103
+ cursor.execute(create_table_sql)
104
+ logger.info(f"Table {self.table_name} created successfully")
105
+
106
+ await self.create_indexes()
107
+ finally:
108
+ cursor.close()
109
+
110
+ async def index_exists(self, index_name: str) -> bool:
111
+ cursor = self.connection.cursor()
112
+ try:
113
+ cursor.execute(
114
+ """
115
+ SELECT COUNT(*)
116
+ FROM USER_INDEXES
117
+ WHERE INDEX_NAME = :index_name
118
+ """,
119
+ index_name=index_name.upper(),
120
+ )
121
+ (count,) = cursor.fetchone()
122
+ return bool(count > 0)
123
+ finally:
124
+ cursor.close()
125
+
126
+ async def create_indexes(self):
127
+ indexes = [
128
+ {
129
+ "name": f"{self.table_name}_content_idx",
130
+ "sql": f"""
131
+ CREATE INDEX {self.table_name}_CONTENT_IDX
132
+ ON {self.table_name}(content)
133
+ INDEXTYPE IS CTXSYS.CONTEXT
134
+ PARAMETERS ('SYNC (EVERY "FREQ=SECONDLY;INTERVAL=5")');
135
+ """,
136
+ },
137
+ {
138
+ "name": f"{self.table_name}_vector_ivf_idx",
139
+ "sql": f"""
140
+ CREATE VECTOR INDEX {self.table_name}_vector_ivf_idx
141
+ ON {self.table_name}(vector)
142
+ ORGANIZATION NEIGHBOR PARTITIONS
143
+ DISTANCE COSINE
144
+ WITH TARGET ACCURACY 95
145
+ """,
146
+ },
147
+ ]
148
+
149
+ for idx in indexes:
150
+ if not await self.index_exists(idx["name"]):
151
+ logger.info(f"Creating index: {idx['name']}")
152
+ cursor = self.connection.cursor()
153
+ try:
154
+ cursor.execute(idx["sql"])
155
+ logger.info(f"Index {idx['name']} created successfully")
156
+ finally:
157
+ cursor.close()
158
+ else:
159
+ logger.info(f"Index {idx['name']} already exists, skipping")
160
+
161
    async def add_chunks(self, embedded_chunks: list[EmbeddedChunk]):
        """Upsert embedded chunks into the Oracle table via a single MERGE.

        Each chunk's embedding is L2-normalized and packed into an
        ``array.array`` matching the table's vector datatype; metadata dicts
        are serialized to JSON strings. Existing rows (matched on chunk_id)
        are updated, new rows inserted, in one ``executemany`` round trip.

        Args:
            embedded_chunks: Chunks with embeddings to persist.

        Raises:
            Exception: Re-raised after logging if the MERGE fails.
        """
        # 'd' = float64 C doubles, 'f' = float32 C floats, to match the column.
        array_type = "d" if self.vector_datatype == "FLOAT64" else "f"
        data = []
        for chunk in embedded_chunks:
            chunk_step = chunk.model_dump()
            data.append(
                {
                    "chunk_id": chunk.chunk_id,
                    "content": chunk.content,
                    # Normalize so stored vectors are unit length (cosine-ready).
                    "vector": array(array_type, normalize_embedding(np.array(chunk.embedding)).astype(float).tolist()),
                    "metadata": json.dumps(chunk_step.get("metadata")),
                    "chunk_metadata": json.dumps(chunk_step.get("chunk_metadata")),
                }
            )
        cursor = self.connection.cursor()
        try:
            # MERGE = upsert keyed on chunk_id; TO_VECTOR converts the bound
            # array into the column's VECTOR(dim, datatype) representation.
            query = f"""
            MERGE INTO {self.table_name} t
            USING (
                SELECT
                    :chunk_id AS chunk_id,
                    :content AS content,
                    :vector AS vector,
                    :metadata AS metadata,
                    :chunk_metadata AS chunk_metadata
                FROM dual
            ) s
            ON (t.chunk_id = s.chunk_id)

            WHEN MATCHED THEN
                UPDATE SET
                    t.content = s.content,
                    t.vector = TO_VECTOR(s.vector, {self.dimensions}, {self.vector_datatype}),
                    t.metadata = s.metadata,
                    t.chunk_metadata = s.chunk_metadata

            WHEN NOT MATCHED THEN
                INSERT (chunk_id, content, vector, metadata, chunk_metadata)
                VALUES (s.chunk_id, s.content, TO_VECTOR(s.vector, {self.dimensions}, {self.vector_datatype}), s.metadata, s.chunk_metadata)
            """
            logger.debug(f"query: {query}")
            cursor.executemany(
                query,
                data,
            )
            logger.info("Merge completed successfully")
        except Exception as e:
            logger.error(f"Error inserting chunks into Oracle 26AI table {self.table_name}: {e}")
            raise
        finally:
            cursor.close()
212
+
213
+ async def query_vector(
214
+ self,
215
+ embedding: NDArray,
216
+ k: int,
217
+ score_threshold: float | None,
218
+ ) -> QueryChunksResponse:
219
+ """
220
+ Oracle vector search using COSINE similarity.
221
+ Returns top-k chunks and normalized similarity scores in [0, 1].
222
+ """
223
+ cursor = self.connection.cursor()
224
+
225
+ # Ensure query vector is L2-normalized
226
+ array_type = "d" if self.vector_datatype == "FLOAT64" else "f"
227
+ query_vector = array(array_type, normalize_embedding(np.array(embedding)))
228
+
229
+ query = f"""
230
+ SELECT
231
+ *
232
+ FROM (
233
+ SELECT
234
+ content,
235
+ chunk_id,
236
+ metadata,
237
+ chunk_metadata,
238
+ vector,
239
+ VECTOR_DISTANCE(vector, :query_vector, COSINE) AS score
240
+ FROM {self.table_name}
241
+ )
242
+ """
243
+
244
+ params: dict = {
245
+ "query_vector": query_vector,
246
+ }
247
+
248
+ if score_threshold is not None:
249
+ query += " WHERE score >= :score_threshold"
250
+ params["score_threshold"] = score_threshold
251
+
252
+ query += " ORDER BY score DESC FETCH FIRST :k ROWS ONLY"
253
+ params["k"] = k
254
+
255
+ logger.debug(query)
256
+ logger.debug(query_vector)
257
+ try:
258
+ cursor.execute(query, params)
259
+ results = cursor.fetchall()
260
+
261
+ chunks: list[EmbeddedChunk] = []
262
+ scores: list[float] = []
263
+
264
+ for row in results:
265
+ content, chunk_id, metadata, chunk_metadata, vector, score = row
266
+
267
+ chunk = EmbeddedChunk(
268
+ content=content.read(),
269
+ chunk_id=chunk_id,
270
+ metadata=metadata,
271
+ embedding=vector,
272
+ chunk_metadata=chunk_metadata,
273
+ embedding_model=self.vector_store.embedding_model,
274
+ embedding_dimension=self.vector_store.embedding_dimension,
275
+ )
276
+
277
+ chunks.append(chunk)
278
+ scores.append(float(score))
279
+ logger.debug(f"result count: {len(chunks)}")
280
+ return QueryChunksResponse(chunks=chunks, scores=scores)
281
+
282
+ except Exception as e:
283
+ logger.error("Error querying vector: %s", e)
284
+ raise
285
+
286
+ finally:
287
+ cursor.close()
288
+
289
+ async def query_keyword(self, query_string: str, k: int, score_threshold: float | None) -> QueryChunksResponse:
290
+ cursor = self.connection.cursor()
291
+
292
+ # Build base query
293
+ base_query = f"""
294
+ SELECT
295
+ content,
296
+ chunk_id,
297
+ metadata,
298
+ chunk_metadata,
299
+ vector,
300
+ score / max_score AS score
301
+ FROM (
302
+ SELECT
303
+ content,
304
+ chunk_id,
305
+ metadata,
306
+ chunk_metadata,
307
+ vector,
308
+ SCORE(1) AS score,
309
+ MAX(SCORE(1)) OVER () AS max_score
310
+ FROM {self.table_name}
311
+ WHERE CONTAINS(content, :query_string, 1) > 0
312
+ )
313
+ """
314
+
315
+ params = {"query_string": query_string, "k": k}
316
+
317
+ if score_threshold is not None:
318
+ base_query += " WHERE score >= :score_threshold"
319
+ params["score_threshold"] = score_threshold
320
+
321
+ query = base_query + " ORDER BY score DESC FETCH FIRST :k ROWS ONLY;"
322
+
323
+ logger.debug(query)
324
+
325
+ try:
326
+ cursor.execute(query, params)
327
+ results = cursor.fetchall()
328
+
329
+ chunks = []
330
+ scores = []
331
+ for row in results:
332
+ content, chunk_id, metadata, chunk_metadata, vector, score = row
333
+ chunk = EmbeddedChunk(
334
+ content=content.read(),
335
+ chunk_id=chunk_id,
336
+ metadata=metadata,
337
+ embedding=vector,
338
+ chunk_metadata=chunk_metadata,
339
+ embedding_model=self.vector_store.embedding_model,
340
+ embedding_dimension=self.vector_store.embedding_dimension,
341
+ )
342
+ chunks.append(chunk)
343
+ scores.append(float(score))
344
+ logger.debug(f"result count: {len(chunks)}")
345
+ return QueryChunksResponse(chunks=chunks, scores=scores)
346
+ except Exception as e:
347
+ logger.error(f"Error performing keyword search: {e}")
348
+ raise
349
+ finally:
350
+ cursor.close()
351
+
352
+ async def query_hybrid(
353
+ self,
354
+ embedding: NDArray,
355
+ query_string: str,
356
+ k: int,
357
+ score_threshold: float | None,
358
+ reranker_type: str,
359
+ reranker_params: dict[str, Any] | None = None,
360
+ ) -> QueryChunksResponse:
361
+ """
362
+ Hybrid search combining vector similarity and keyword search using configurable reranking.
363
+
364
+ Args:
365
+ embedding: The query embedding vector
366
+ query_string: The text query for keyword search
367
+ k: Number of results to return
368
+ score_threshold: Minimum similarity score threshold
369
+ reranker_type: Type of reranker to use ("rrf" or "weighted")
370
+ reranker_params: Parameters for the reranker
371
+
372
+ Returns:
373
+ QueryChunksResponse with combined results
374
+ """
375
+ if reranker_params is None:
376
+ reranker_params = {}
377
+
378
+ # Get results from both search methods
379
+ vector_response = await self.query_vector(embedding, k, score_threshold)
380
+ keyword_response = await self.query_keyword(query_string, k, score_threshold)
381
+
382
+ # Convert responses to score dictionaries using chunk_id
383
+ vector_scores = {
384
+ chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
385
+ }
386
+ keyword_scores = {
387
+ chunk.chunk_id: score
388
+ for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
389
+ }
390
+
391
+ # Combine scores using the reranking utility
392
+ combined_scores = WeightedInMemoryAggregator.combine_search_results(
393
+ vector_scores, keyword_scores, reranker_type, reranker_params
394
+ )
395
+
396
+ # Efficient top-k selection because it only tracks the k best candidates it's seen so far
397
+ top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
398
+
399
+ # Filter by score threshold
400
+ filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= (score_threshold or 0)]
401
+
402
+ # Create a map of chunk_id to chunk for both responses
403
+ chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
404
+
405
+ # Use the map to look up chunks by their IDs
406
+ chunks = []
407
+ scores = []
408
+ for doc_id, score in filtered_items:
409
+ if doc_id in chunk_map:
410
+ chunks.append(chunk_map[doc_id])
411
+ scores.append(score)
412
+
413
+ return QueryChunksResponse(chunks=chunks, scores=scores)
414
+
415
+ async def delete(self):
416
+ try:
417
+ with self.connection.cursor() as cursor:
418
+ cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
419
+ logger.info("Dropped table: {self.table_name}")
420
+ except oracledb.DatabaseError as e:
421
+ logger.error(f"Error dropping table {self.table_name}: {e}")
422
+ raise
423
+
424
+ async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
425
+ chunk_ids = [c.chunk_id for c in chunks_for_deletion]
426
+ cursor = self.connection.cursor()
427
+ try:
428
+ cursor.execute(
429
+ f"""
430
+ DELETE FROM {self.table_name}
431
+ WHERE chunk_id IN ({", ".join([f"'{chunk_id}'" for chunk_id in chunk_ids])})
432
+ """
433
+ )
434
+ except Exception as e:
435
+ logger.error(f"Error deleting chunks from Oracle 26AI table {self.table_name}: {e}")
436
+ raise
437
+ finally:
438
+ cursor.close()
439
+
440
+
441
class OCI26aiVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
    """Vector IO adapter backed by an Oracle 26ai database.

    Keeps one ``OCI26aiIndex`` per registered vector store in ``self.cache``
    and persists vector-store metadata in the configured KV store so indexes
    can be re-hydrated across restarts.
    """

    def __init__(
        self,
        config: OCI26aiVectorIOConfig,
        inference_api: Inference,
        files_api: Files | None,
    ) -> None:
        super().__init__(inference_api=inference_api, files_api=files_api, kvstore=None)
        self.config = config
        # Maps vector_store identifier -> VectorStoreWithIndex.
        self.cache: dict[str, VectorStoreWithIndex] = {}
        self.pool = None
        self.inference_api = inference_api
        self.vector_store_table = None
        # Fix: established in initialize(); defined here so shutdown() does not
        # raise AttributeError when initialize() was never called.
        self.connection = None

    async def initialize(self) -> None:
        """Open the Oracle connection and re-load all persisted vector stores."""
        logger.info("Initializing OCI26aiVectorIOAdapter")
        self.kvstore = await kvstore_impl(self.config.persistence)
        await self.initialize_openai_vector_stores()

        try:
            self.connection = oracledb.connect(
                user=self.config.user,
                password=self.config.password,
                dsn=self.config.conn_str,
                config_dir=self.config.tnsnames_loc,
                wallet_location=self.config.ewallet_pem_loc,
                wallet_password=self.config.ewallet_password,
                expire_time=1,  # minutes
            )
            # Commit each statement implicitly; the adapter does not batch transactions.
            self.connection.autocommit = True
            logger.info("Oracle connection created successfully")
        except Exception as e:
            logger.error(f"Error creating Oracle connection: {e}")
            raise

        # Re-hydrate an index for every vector store persisted in the KV store.
        start_key = OPENAI_VECTOR_STORES_PREFIX
        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
        stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
        for vector_store_data in stored_vector_stores:
            vector_store = VectorStore.model_validate_json(vector_store_data)
            logger.info(f"Loading index {vector_store.vector_store_name}: {vector_store.vector_store_id}")
            oci_index = OCI26aiIndex(
                connection=self.connection,
                vector_store=vector_store,
                kvstore=self.kvstore,
                vector_datatype=self.config.vector_datatype,
            )
            await oci_index.initialize()
            index = VectorStoreWithIndex(vector_store, index=oci_index, inference_api=self.inference_api)
            self.cache[vector_store.identifier] = index

        logger.info(f"Completed loading {len(stored_vector_stores)} indexes")

    async def shutdown(self) -> None:
        """Close the Oracle connection and release mixin resources."""
        logger.info("Shutting down Oracle connection")
        if self.connection is not None:
            self.connection.close()
        # Clean up mixin resources (file batch tasks)
        await super().shutdown()

    async def register_vector_store(self, vector_store: VectorStore) -> None:
        """Persist *vector_store* metadata and create/cache its backing index.

        Raises:
            RuntimeError: if initialize() has not been called yet.
        """
        if self.kvstore is None:
            raise RuntimeError("KVStore not initialized. Call initialize() before registering vector stores.")

        # Save to kvstore for persistence
        key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store.identifier}"
        await self.kvstore.set(key=key, value=vector_store.model_dump_json())

        # NOTE(review): consistency_level looks carried over from a Milvus
        # adapter; confirm OCI26aiIndex actually honors it. Also note that,
        # unlike initialize(), no kvstore is passed here — verify intentional.
        if isinstance(self.config, OCI26aiVectorIOConfig):
            consistency_level = self.config.consistency_level
        else:
            consistency_level = "Strong"
        oci_index = OCI26aiIndex(
            connection=self.connection,
            vector_store=vector_store,
            consistency_level=consistency_level,
            vector_datatype=self.config.vector_datatype,
        )
        index = VectorStoreWithIndex(
            vector_store=vector_store,
            index=oci_index,
            inference_api=self.inference_api,
        )
        await oci_index.initialize()
        self.cache[vector_store.identifier] = index

    async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
        """Return the index for *vector_store_id*, loading from the KV store on a cache miss.

        Raises:
            RuntimeError: if initialize() has not been called yet.
            VectorStoreNotFoundError: if no metadata exists for the id.
        """
        if vector_store_id in self.cache:
            return self.cache[vector_store_id]

        # Try to load from kvstore
        if self.kvstore is None:
            raise RuntimeError("KVStore not initialized. Call initialize() before using vector stores.")

        key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}"
        vector_store_data = await self.kvstore.get(key)
        if not vector_store_data:
            raise VectorStoreNotFoundError(vector_store_id)

        vector_store = VectorStore.model_validate_json(vector_store_data)
        index = VectorStoreWithIndex(
            vector_store=vector_store,
            index=OCI26aiIndex(
                connection=self.connection,
                vector_store=vector_store,
                kvstore=self.kvstore,
                vector_datatype=self.config.vector_datatype,
            ),
            inference_api=self.inference_api,
        )
        self.cache[vector_store_id] = index
        return index

    async def unregister_vector_store(self, vector_store_id: str) -> None:
        """Drop the backing index (if cached) and delete persisted metadata."""
        # Remove provider index and cache
        if vector_store_id in self.cache:
            await self.cache[vector_store_id].index.delete()
            del self.cache[vector_store_id]

        # Delete vector DB metadata from KV store
        if self.kvstore is None:
            raise RuntimeError("KVStore not initialized. Call initialize() before unregistering vector stores.")
        await self.kvstore.delete(key=f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}")

    async def insert_chunks(
        self, vector_store_id: str, chunks: list[EmbeddedChunk], ttl_seconds: int | None = None
    ) -> None:
        """Insert pre-embedded chunks into the store's index.

        ``ttl_seconds`` is accepted for interface compatibility but unused here.
        """
        index = await self._get_and_cache_vector_store_index(vector_store_id)
        if not index:
            raise VectorStoreNotFoundError(vector_store_id)

        await index.insert_chunks(chunks)

    async def query_chunks(
        self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
    ) -> QueryChunksResponse:
        """Query the store's index, defaulting embedding_dimensions from the store metadata."""
        index = await self._get_and_cache_vector_store_index(vector_store_id)
        if not index:
            raise VectorStoreNotFoundError(vector_store_id)

        if params is None:
            params = {}
        if "embedding_dimensions" not in params:
            params["embedding_dimensions"] = index.vector_store.embedding_dimension

        return await index.query_chunks(query, params)

    async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
        """Delete chunks from an Oracle 26ai vector store."""
        # Fix: the original docstring said "milvus" — copy-paste leftover.
        index = await self._get_and_cache_vector_store_index(store_id)
        if not index:
            raise VectorStoreNotFoundError(store_id)

        await index.index.delete_chunks(chunks_for_deletion)
@@ -4,14 +4,70 @@
4
4
  # This source code is licensed under the terms described in the LICENSE file in
5
5
  # the root directory of this source tree.
6
6
 
7
- from typing import Any
7
+ from enum import StrEnum
8
+ from typing import Annotated, Any, Literal, Self
8
9
 
9
- from pydantic import BaseModel, Field
10
+ from pydantic import BaseModel, Field, model_validator
10
11
 
11
12
  from llama_stack.core.storage.datatypes import KVStoreReference
12
13
  from llama_stack_api import json_schema_type
13
14
 
14
15
 
16
+ class PGVectorIndexType(StrEnum):
17
+ """Supported pgvector vector index types in Llama Stack."""
18
+
19
+ HNSW = "HNSW"
20
+ IVFFlat = "IVFFlat"
21
+
22
+
23
class PGVectorHNSWVectorIndex(BaseModel):
    """Configuration for PGVector HNSW (Hierarchical Navigable Small Worlds) vector index.
    https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
    """

    # Discriminator value for the PGVectorIndexConfig tagged union.
    type: Literal[PGVectorIndexType.HNSW] = PGVectorIndexType.HNSW
    # Maximum connections per graph node; gt=0 is skipped by pydantic when None.
    m: int | None = Field(
        gt=0,
        default=16,
        description="PGVector's HNSW index parameter - maximum number of edges each vertex has to its neighboring vertices in the graph",
    )
    # Candidate-list size at index build time; larger -> better recall, slower build.
    ef_construction: int | None = Field(
        gt=0,
        default=64,
        description="PGVector's HNSW index parameter - size of the dynamic candidate list used for graph construction",
    )
40
+
41
class PGVectorIVFFlatVectorIndex(BaseModel):
    """Configuration for PGVector IVFFlat (Inverted File with Flat Compression) vector index.
    https://github.com/pgvector/pgvector?tab=readme-ov-file#ivfflat
    """

    # Discriminator value for the PGVectorIndexConfig tagged union.
    type: Literal[PGVectorIndexType.IVFFlat] = PGVectorIndexType.IVFFlat
    lists: int | None = Field(
        gt=0, default=100, description="PGVector's IVFFlat index parameter - number of lists index divides vectors into"
    )
    probes: int | None = Field(
        gt=0,
        default=10,
        description="PGVector's IVFFlat index parameter - number of lists index searches through during ANN search",
    )

    @model_validator(mode="after")
    def validate_probes(self) -> Self:
        """Reject configurations where probes >= lists (search would degenerate to exact scan).

        Fix: both fields are ``int | None``; the original comparison raised
        TypeError when either was explicitly set to None. The check now only
        runs when both values are present.
        """
        if self.probes is not None and self.lists is not None and self.probes >= self.lists:
            raise ValueError(
                "probes parameter for PGVector IVFFlat index can't be greater than or equal to the number of lists in the index to allow ANN search."
            )
        return self
63
+
64
+
65
# Discriminated union of supported pgvector index configurations; pydantic
# selects HNSW vs IVFFlat from the `type` field at validation time.
PGVectorIndexConfig = Annotated[
    PGVectorHNSWVectorIndex | PGVectorIVFFlatVectorIndex,
    Field(discriminator="type"),
]
68
+ ]
69
+
70
+
15
71
  @json_schema_type
16
72
  class PGVectorVectorIOConfig(BaseModel):
17
73
  host: str | None = Field(default="localhost")
@@ -19,6 +75,13 @@ class PGVectorVectorIOConfig(BaseModel):
19
75
  db: str | None = Field(default="postgres")
20
76
  user: str | None = Field(default="postgres")
21
77
  password: str | None = Field(default="mysecretpassword")
78
+ distance_metric: Literal["COSINE", "L2", "L1", "INNER_PRODUCT"] | None = Field(
79
+ default="COSINE", description="PGVector distance metric used for vector search in PGVectorIndex"
80
+ )
81
+ vector_index: PGVectorIndexConfig | None = Field(
82
+ default_factory=PGVectorHNSWVectorIndex,
83
+ description="PGVector vector index used for Approximate Nearest Neighbor (ANN) search",
84
+ )
22
85
  persistence: KVStoreReference | None = Field(
23
86
  description="Config for KV store backend (SQLite only for now)", default=None
24
87
  )
@@ -40,6 +103,10 @@ class PGVectorVectorIOConfig(BaseModel):
40
103
  "db": db,
41
104
  "user": user,
42
105
  "password": password,
106
+ "distance_metric": "COSINE",
107
+ "vector_index": PGVectorHNSWVectorIndex(m=16, ef_construction=64).model_dump(
108
+ mode="json", exclude_none=True
109
+ ),
43
110
  "persistence": KVStoreReference(
44
111
  backend="kv_default",
45
112
  namespace="vector_io::pgvector",