ai-parrot 0.17.2 (cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (535)
  1. agentui/.prettierrc +15 -0
  2. agentui/QUICKSTART.md +272 -0
  3. agentui/README.md +59 -0
  4. agentui/env.example +16 -0
  5. agentui/jsconfig.json +14 -0
  6. agentui/package-lock.json +4242 -0
  7. agentui/package.json +34 -0
  8. agentui/scripts/postinstall/apply-patches.mjs +260 -0
  9. agentui/src/app.css +61 -0
  10. agentui/src/app.d.ts +13 -0
  11. agentui/src/app.html +12 -0
  12. agentui/src/components/LoadingSpinner.svelte +64 -0
  13. agentui/src/components/ThemeSwitcher.svelte +159 -0
  14. agentui/src/components/index.js +4 -0
  15. agentui/src/lib/api/bots.ts +60 -0
  16. agentui/src/lib/api/chat.ts +22 -0
  17. agentui/src/lib/api/http.ts +25 -0
  18. agentui/src/lib/components/BotCard.svelte +33 -0
  19. agentui/src/lib/components/ChatBubble.svelte +63 -0
  20. agentui/src/lib/components/Toast.svelte +21 -0
  21. agentui/src/lib/config.ts +20 -0
  22. agentui/src/lib/stores/auth.svelte.ts +73 -0
  23. agentui/src/lib/stores/theme.svelte.js +64 -0
  24. agentui/src/lib/stores/toast.svelte.ts +31 -0
  25. agentui/src/lib/utils/conversation.ts +39 -0
  26. agentui/src/routes/+layout.svelte +20 -0
  27. agentui/src/routes/+page.svelte +232 -0
  28. agentui/src/routes/login/+page.svelte +200 -0
  29. agentui/src/routes/talk/[agentId]/+page.svelte +297 -0
  30. agentui/src/routes/talk/[agentId]/+page.ts +7 -0
  31. agentui/static/README.md +1 -0
  32. agentui/svelte.config.js +11 -0
  33. agentui/tailwind.config.ts +53 -0
  34. agentui/tsconfig.json +3 -0
  35. agentui/vite.config.ts +10 -0
  36. ai_parrot-0.17.2.dist-info/METADATA +472 -0
  37. ai_parrot-0.17.2.dist-info/RECORD +535 -0
  38. ai_parrot-0.17.2.dist-info/WHEEL +6 -0
  39. ai_parrot-0.17.2.dist-info/entry_points.txt +2 -0
  40. ai_parrot-0.17.2.dist-info/licenses/LICENSE +21 -0
  41. ai_parrot-0.17.2.dist-info/top_level.txt +6 -0
  42. crew-builder/.prettierrc +15 -0
  43. crew-builder/QUICKSTART.md +259 -0
  44. crew-builder/README.md +113 -0
  45. crew-builder/env.example +17 -0
  46. crew-builder/jsconfig.json +14 -0
  47. crew-builder/package-lock.json +4182 -0
  48. crew-builder/package.json +37 -0
  49. crew-builder/scripts/postinstall/apply-patches.mjs +260 -0
  50. crew-builder/src/app.css +62 -0
  51. crew-builder/src/app.d.ts +13 -0
  52. crew-builder/src/app.html +12 -0
  53. crew-builder/src/components/LoadingSpinner.svelte +64 -0
  54. crew-builder/src/components/ThemeSwitcher.svelte +149 -0
  55. crew-builder/src/components/index.js +9 -0
  56. crew-builder/src/lib/api/bots.ts +60 -0
  57. crew-builder/src/lib/api/chat.ts +80 -0
  58. crew-builder/src/lib/api/client.ts +56 -0
  59. crew-builder/src/lib/api/crew/crew.ts +136 -0
  60. crew-builder/src/lib/api/index.ts +5 -0
  61. crew-builder/src/lib/api/o365/auth.ts +65 -0
  62. crew-builder/src/lib/auth/auth.ts +54 -0
  63. crew-builder/src/lib/components/AgentNode.svelte +43 -0
  64. crew-builder/src/lib/components/BotCard.svelte +33 -0
  65. crew-builder/src/lib/components/ChatBubble.svelte +67 -0
  66. crew-builder/src/lib/components/ConfigPanel.svelte +278 -0
  67. crew-builder/src/lib/components/JsonTreeNode.svelte +76 -0
  68. crew-builder/src/lib/components/JsonViewer.svelte +24 -0
  69. crew-builder/src/lib/components/MarkdownEditor.svelte +48 -0
  70. crew-builder/src/lib/components/ThemeToggle.svelte +36 -0
  71. crew-builder/src/lib/components/Toast.svelte +67 -0
  72. crew-builder/src/lib/components/Toolbar.svelte +157 -0
  73. crew-builder/src/lib/components/index.ts +10 -0
  74. crew-builder/src/lib/config.ts +8 -0
  75. crew-builder/src/lib/stores/auth.svelte.ts +228 -0
  76. crew-builder/src/lib/stores/crewStore.ts +369 -0
  77. crew-builder/src/lib/stores/theme.svelte.js +145 -0
  78. crew-builder/src/lib/stores/toast.svelte.ts +69 -0
  79. crew-builder/src/lib/utils/conversation.ts +39 -0
  80. crew-builder/src/lib/utils/markdown.ts +122 -0
  81. crew-builder/src/lib/utils/talkHistory.ts +47 -0
  82. crew-builder/src/routes/+layout.svelte +20 -0
  83. crew-builder/src/routes/+page.svelte +539 -0
  84. crew-builder/src/routes/agents/+page.svelte +247 -0
  85. crew-builder/src/routes/agents/[agentId]/+page.svelte +288 -0
  86. crew-builder/src/routes/agents/[agentId]/+page.ts +7 -0
  87. crew-builder/src/routes/builder/+page.svelte +204 -0
  88. crew-builder/src/routes/crew/ask/+page.svelte +1052 -0
  89. crew-builder/src/routes/crew/ask/+page.ts +1 -0
  90. crew-builder/src/routes/integrations/o365/+page.svelte +304 -0
  91. crew-builder/src/routes/login/+page.svelte +197 -0
  92. crew-builder/src/routes/talk/[agentId]/+page.svelte +487 -0
  93. crew-builder/src/routes/talk/[agentId]/+page.ts +7 -0
  94. crew-builder/static/README.md +1 -0
  95. crew-builder/svelte.config.js +11 -0
  96. crew-builder/tailwind.config.ts +53 -0
  97. crew-builder/tsconfig.json +3 -0
  98. crew-builder/vite.config.ts +10 -0
  99. mcp_servers/calculator_server.py +309 -0
  100. parrot/__init__.py +27 -0
  101. parrot/__pycache__/__init__.cpython-310.pyc +0 -0
  102. parrot/__pycache__/version.cpython-310.pyc +0 -0
  103. parrot/_version.py +34 -0
  104. parrot/a2a/__init__.py +48 -0
  105. parrot/a2a/client.py +658 -0
  106. parrot/a2a/discovery.py +89 -0
  107. parrot/a2a/mixin.py +257 -0
  108. parrot/a2a/models.py +376 -0
  109. parrot/a2a/server.py +770 -0
  110. parrot/agents/__init__.py +29 -0
  111. parrot/bots/__init__.py +12 -0
  112. parrot/bots/a2a_agent.py +19 -0
  113. parrot/bots/abstract.py +3139 -0
  114. parrot/bots/agent.py +1129 -0
  115. parrot/bots/basic.py +9 -0
  116. parrot/bots/chatbot.py +669 -0
  117. parrot/bots/data.py +1618 -0
  118. parrot/bots/database/__init__.py +5 -0
  119. parrot/bots/database/abstract.py +3071 -0
  120. parrot/bots/database/cache.py +286 -0
  121. parrot/bots/database/models.py +468 -0
  122. parrot/bots/database/prompts.py +154 -0
  123. parrot/bots/database/retries.py +98 -0
  124. parrot/bots/database/router.py +269 -0
  125. parrot/bots/database/sql.py +41 -0
  126. parrot/bots/db/__init__.py +6 -0
  127. parrot/bots/db/abstract.py +556 -0
  128. parrot/bots/db/bigquery.py +602 -0
  129. parrot/bots/db/cache.py +85 -0
  130. parrot/bots/db/documentdb.py +668 -0
  131. parrot/bots/db/elastic.py +1014 -0
  132. parrot/bots/db/influx.py +898 -0
  133. parrot/bots/db/mock.py +96 -0
  134. parrot/bots/db/multi.py +783 -0
  135. parrot/bots/db/prompts.py +185 -0
  136. parrot/bots/db/sql.py +1255 -0
  137. parrot/bots/db/tools.py +212 -0
  138. parrot/bots/document.py +680 -0
  139. parrot/bots/hrbot.py +15 -0
  140. parrot/bots/kb.py +170 -0
  141. parrot/bots/mcp.py +36 -0
  142. parrot/bots/orchestration/README.md +463 -0
  143. parrot/bots/orchestration/__init__.py +1 -0
  144. parrot/bots/orchestration/agent.py +155 -0
  145. parrot/bots/orchestration/crew.py +3330 -0
  146. parrot/bots/orchestration/fsm.py +1179 -0
  147. parrot/bots/orchestration/hr.py +434 -0
  148. parrot/bots/orchestration/storage/__init__.py +4 -0
  149. parrot/bots/orchestration/storage/memory.py +100 -0
  150. parrot/bots/orchestration/storage/mixin.py +119 -0
  151. parrot/bots/orchestration/verify.py +202 -0
  152. parrot/bots/product.py +204 -0
  153. parrot/bots/prompts/__init__.py +96 -0
  154. parrot/bots/prompts/agents.py +155 -0
  155. parrot/bots/prompts/data.py +216 -0
  156. parrot/bots/prompts/output_generation.py +8 -0
  157. parrot/bots/scraper/__init__.py +3 -0
  158. parrot/bots/scraper/models.py +122 -0
  159. parrot/bots/scraper/scraper.py +1173 -0
  160. parrot/bots/scraper/templates.py +115 -0
  161. parrot/bots/stores/__init__.py +5 -0
  162. parrot/bots/stores/local.py +172 -0
  163. parrot/bots/webdev.py +81 -0
  164. parrot/cli.py +17 -0
  165. parrot/clients/__init__.py +16 -0
  166. parrot/clients/base.py +1491 -0
  167. parrot/clients/claude.py +1191 -0
  168. parrot/clients/factory.py +129 -0
  169. parrot/clients/google.py +4567 -0
  170. parrot/clients/gpt.py +1975 -0
  171. parrot/clients/grok.py +432 -0
  172. parrot/clients/groq.py +986 -0
  173. parrot/clients/hf.py +582 -0
  174. parrot/clients/models.py +18 -0
  175. parrot/conf.py +395 -0
  176. parrot/embeddings/__init__.py +9 -0
  177. parrot/embeddings/base.py +157 -0
  178. parrot/embeddings/google.py +98 -0
  179. parrot/embeddings/huggingface.py +74 -0
  180. parrot/embeddings/openai.py +84 -0
  181. parrot/embeddings/processor.py +88 -0
  182. parrot/exceptions.c +13868 -0
  183. parrot/exceptions.cpython-310-x86_64-linux-gnu.so +0 -0
  184. parrot/exceptions.pxd +22 -0
  185. parrot/exceptions.pxi +15 -0
  186. parrot/exceptions.pyx +44 -0
  187. parrot/generators/__init__.py +29 -0
  188. parrot/generators/base.py +200 -0
  189. parrot/generators/html.py +293 -0
  190. parrot/generators/react.py +205 -0
  191. parrot/generators/streamlit.py +203 -0
  192. parrot/generators/template.py +105 -0
  193. parrot/handlers/__init__.py +4 -0
  194. parrot/handlers/agent.py +861 -0
  195. parrot/handlers/agents/__init__.py +1 -0
  196. parrot/handlers/agents/abstract.py +900 -0
  197. parrot/handlers/bots.py +338 -0
  198. parrot/handlers/chat.py +915 -0
  199. parrot/handlers/creation.sql +192 -0
  200. parrot/handlers/crew/ARCHITECTURE.md +362 -0
  201. parrot/handlers/crew/README_BOTMANAGER_PERSISTENCE.md +303 -0
  202. parrot/handlers/crew/README_REDIS_PERSISTENCE.md +366 -0
  203. parrot/handlers/crew/__init__.py +0 -0
  204. parrot/handlers/crew/handler.py +801 -0
  205. parrot/handlers/crew/models.py +229 -0
  206. parrot/handlers/crew/redis_persistence.py +523 -0
  207. parrot/handlers/jobs/__init__.py +10 -0
  208. parrot/handlers/jobs/job.py +384 -0
  209. parrot/handlers/jobs/mixin.py +627 -0
  210. parrot/handlers/jobs/models.py +115 -0
  211. parrot/handlers/jobs/worker.py +31 -0
  212. parrot/handlers/models.py +596 -0
  213. parrot/handlers/o365_auth.py +105 -0
  214. parrot/handlers/stream.py +337 -0
  215. parrot/interfaces/__init__.py +6 -0
  216. parrot/interfaces/aws.py +143 -0
  217. parrot/interfaces/credentials.py +113 -0
  218. parrot/interfaces/database.py +27 -0
  219. parrot/interfaces/google.py +1123 -0
  220. parrot/interfaces/hierarchy.py +1227 -0
  221. parrot/interfaces/http.py +651 -0
  222. parrot/interfaces/images/__init__.py +0 -0
  223. parrot/interfaces/images/plugins/__init__.py +24 -0
  224. parrot/interfaces/images/plugins/abstract.py +58 -0
  225. parrot/interfaces/images/plugins/analisys.py +148 -0
  226. parrot/interfaces/images/plugins/classify.py +150 -0
  227. parrot/interfaces/images/plugins/classifybase.py +182 -0
  228. parrot/interfaces/images/plugins/detect.py +150 -0
  229. parrot/interfaces/images/plugins/exif.py +1103 -0
  230. parrot/interfaces/images/plugins/hash.py +52 -0
  231. parrot/interfaces/images/plugins/vision.py +104 -0
  232. parrot/interfaces/images/plugins/yolo.py +66 -0
  233. parrot/interfaces/images/plugins/zerodetect.py +197 -0
  234. parrot/interfaces/o365.py +978 -0
  235. parrot/interfaces/onedrive.py +822 -0
  236. parrot/interfaces/sharepoint.py +1435 -0
  237. parrot/interfaces/soap.py +257 -0
  238. parrot/loaders/__init__.py +8 -0
  239. parrot/loaders/abstract.py +1131 -0
  240. parrot/loaders/audio.py +199 -0
  241. parrot/loaders/basepdf.py +53 -0
  242. parrot/loaders/basevideo.py +1568 -0
  243. parrot/loaders/csv.py +409 -0
  244. parrot/loaders/docx.py +116 -0
  245. parrot/loaders/epubloader.py +316 -0
  246. parrot/loaders/excel.py +199 -0
  247. parrot/loaders/factory.py +55 -0
  248. parrot/loaders/files/__init__.py +0 -0
  249. parrot/loaders/files/abstract.py +39 -0
  250. parrot/loaders/files/html.py +26 -0
  251. parrot/loaders/files/text.py +63 -0
  252. parrot/loaders/html.py +152 -0
  253. parrot/loaders/markdown.py +442 -0
  254. parrot/loaders/pdf.py +373 -0
  255. parrot/loaders/pdfmark.py +320 -0
  256. parrot/loaders/pdftables.py +506 -0
  257. parrot/loaders/ppt.py +476 -0
  258. parrot/loaders/qa.py +63 -0
  259. parrot/loaders/splitters/__init__.py +10 -0
  260. parrot/loaders/splitters/base.py +138 -0
  261. parrot/loaders/splitters/md.py +228 -0
  262. parrot/loaders/splitters/token.py +143 -0
  263. parrot/loaders/txt.py +26 -0
  264. parrot/loaders/video.py +89 -0
  265. parrot/loaders/videolocal.py +218 -0
  266. parrot/loaders/videounderstanding.py +377 -0
  267. parrot/loaders/vimeo.py +167 -0
  268. parrot/loaders/web.py +599 -0
  269. parrot/loaders/youtube.py +504 -0
  270. parrot/manager/__init__.py +5 -0
  271. parrot/manager/manager.py +1030 -0
  272. parrot/mcp/__init__.py +28 -0
  273. parrot/mcp/adapter.py +105 -0
  274. parrot/mcp/cli.py +174 -0
  275. parrot/mcp/client.py +119 -0
  276. parrot/mcp/config.py +75 -0
  277. parrot/mcp/integration.py +842 -0
  278. parrot/mcp/oauth.py +933 -0
  279. parrot/mcp/server.py +225 -0
  280. parrot/mcp/transports/__init__.py +3 -0
  281. parrot/mcp/transports/base.py +279 -0
  282. parrot/mcp/transports/grpc_session.py +163 -0
  283. parrot/mcp/transports/http.py +312 -0
  284. parrot/mcp/transports/mcp.proto +108 -0
  285. parrot/mcp/transports/quic.py +1082 -0
  286. parrot/mcp/transports/sse.py +330 -0
  287. parrot/mcp/transports/stdio.py +309 -0
  288. parrot/mcp/transports/unix.py +395 -0
  289. parrot/mcp/transports/websocket.py +547 -0
  290. parrot/memory/__init__.py +16 -0
  291. parrot/memory/abstract.py +209 -0
  292. parrot/memory/agent.py +32 -0
  293. parrot/memory/cache.py +175 -0
  294. parrot/memory/core.py +555 -0
  295. parrot/memory/file.py +153 -0
  296. parrot/memory/mem.py +131 -0
  297. parrot/memory/redis.py +613 -0
  298. parrot/models/__init__.py +46 -0
  299. parrot/models/basic.py +118 -0
  300. parrot/models/compliance.py +208 -0
  301. parrot/models/crew.py +395 -0
  302. parrot/models/detections.py +654 -0
  303. parrot/models/generation.py +85 -0
  304. parrot/models/google.py +223 -0
  305. parrot/models/groq.py +23 -0
  306. parrot/models/openai.py +30 -0
  307. parrot/models/outputs.py +285 -0
  308. parrot/models/responses.py +938 -0
  309. parrot/notifications/__init__.py +743 -0
  310. parrot/openapi/__init__.py +3 -0
  311. parrot/openapi/components.yaml +641 -0
  312. parrot/openapi/config.py +322 -0
  313. parrot/outputs/__init__.py +32 -0
  314. parrot/outputs/formats/__init__.py +108 -0
  315. parrot/outputs/formats/altair.py +359 -0
  316. parrot/outputs/formats/application.py +122 -0
  317. parrot/outputs/formats/base.py +351 -0
  318. parrot/outputs/formats/bokeh.py +356 -0
  319. parrot/outputs/formats/card.py +424 -0
  320. parrot/outputs/formats/chart.py +436 -0
  321. parrot/outputs/formats/d3.py +255 -0
  322. parrot/outputs/formats/echarts.py +310 -0
  323. parrot/outputs/formats/generators/__init__.py +0 -0
  324. parrot/outputs/formats/generators/abstract.py +61 -0
  325. parrot/outputs/formats/generators/panel.py +145 -0
  326. parrot/outputs/formats/generators/streamlit.py +86 -0
  327. parrot/outputs/formats/generators/terminal.py +63 -0
  328. parrot/outputs/formats/holoviews.py +310 -0
  329. parrot/outputs/formats/html.py +147 -0
  330. parrot/outputs/formats/jinja2.py +46 -0
  331. parrot/outputs/formats/json.py +87 -0
  332. parrot/outputs/formats/map.py +933 -0
  333. parrot/outputs/formats/markdown.py +172 -0
  334. parrot/outputs/formats/matplotlib.py +237 -0
  335. parrot/outputs/formats/mixins/__init__.py +0 -0
  336. parrot/outputs/formats/mixins/emaps.py +855 -0
  337. parrot/outputs/formats/plotly.py +341 -0
  338. parrot/outputs/formats/seaborn.py +310 -0
  339. parrot/outputs/formats/table.py +397 -0
  340. parrot/outputs/formats/template_report.py +138 -0
  341. parrot/outputs/formats/yaml.py +125 -0
  342. parrot/outputs/formatter.py +152 -0
  343. parrot/outputs/templates/__init__.py +95 -0
  344. parrot/pipelines/__init__.py +0 -0
  345. parrot/pipelines/abstract.py +210 -0
  346. parrot/pipelines/detector.py +124 -0
  347. parrot/pipelines/models.py +90 -0
  348. parrot/pipelines/planogram.py +3002 -0
  349. parrot/pipelines/table.sql +97 -0
  350. parrot/plugins/__init__.py +106 -0
  351. parrot/plugins/importer.py +80 -0
  352. parrot/py.typed +0 -0
  353. parrot/registry/__init__.py +18 -0
  354. parrot/registry/registry.py +594 -0
  355. parrot/scheduler/__init__.py +1189 -0
  356. parrot/scheduler/models.py +60 -0
  357. parrot/security/__init__.py +16 -0
  358. parrot/security/prompt_injection.py +268 -0
  359. parrot/security/security_events.sql +25 -0
  360. parrot/services/__init__.py +1 -0
  361. parrot/services/mcp/__init__.py +8 -0
  362. parrot/services/mcp/config.py +13 -0
  363. parrot/services/mcp/server.py +295 -0
  364. parrot/services/o365_remote_auth.py +235 -0
  365. parrot/stores/__init__.py +7 -0
  366. parrot/stores/abstract.py +352 -0
  367. parrot/stores/arango.py +1090 -0
  368. parrot/stores/bigquery.py +1377 -0
  369. parrot/stores/cache.py +106 -0
  370. parrot/stores/empty.py +10 -0
  371. parrot/stores/faiss_store.py +1157 -0
  372. parrot/stores/kb/__init__.py +9 -0
  373. parrot/stores/kb/abstract.py +68 -0
  374. parrot/stores/kb/cache.py +165 -0
  375. parrot/stores/kb/doc.py +325 -0
  376. parrot/stores/kb/hierarchy.py +346 -0
  377. parrot/stores/kb/local.py +457 -0
  378. parrot/stores/kb/prompt.py +28 -0
  379. parrot/stores/kb/redis.py +659 -0
  380. parrot/stores/kb/store.py +115 -0
  381. parrot/stores/kb/user.py +374 -0
  382. parrot/stores/models.py +59 -0
  383. parrot/stores/pgvector.py +3 -0
  384. parrot/stores/postgres.py +2853 -0
  385. parrot/stores/utils/__init__.py +0 -0
  386. parrot/stores/utils/chunking.py +197 -0
  387. parrot/telemetry/__init__.py +3 -0
  388. parrot/telemetry/mixin.py +111 -0
  389. parrot/template/__init__.py +3 -0
  390. parrot/template/engine.py +259 -0
  391. parrot/tools/__init__.py +23 -0
  392. parrot/tools/abstract.py +644 -0
  393. parrot/tools/agent.py +363 -0
  394. parrot/tools/arangodbsearch.py +537 -0
  395. parrot/tools/arxiv_tool.py +188 -0
  396. parrot/tools/calculator/__init__.py +3 -0
  397. parrot/tools/calculator/operations/__init__.py +38 -0
  398. parrot/tools/calculator/operations/calculus.py +80 -0
  399. parrot/tools/calculator/operations/statistics.py +76 -0
  400. parrot/tools/calculator/tool.py +150 -0
  401. parrot/tools/cloudwatch.py +988 -0
  402. parrot/tools/codeinterpreter/__init__.py +127 -0
  403. parrot/tools/codeinterpreter/executor.py +371 -0
  404. parrot/tools/codeinterpreter/internals.py +473 -0
  405. parrot/tools/codeinterpreter/models.py +643 -0
  406. parrot/tools/codeinterpreter/prompts.py +224 -0
  407. parrot/tools/codeinterpreter/tool.py +664 -0
  408. parrot/tools/company_info/__init__.py +6 -0
  409. parrot/tools/company_info/tool.py +1138 -0
  410. parrot/tools/correlationanalysis.py +437 -0
  411. parrot/tools/database/abstract.py +286 -0
  412. parrot/tools/database/bq.py +115 -0
  413. parrot/tools/database/cache.py +284 -0
  414. parrot/tools/database/models.py +95 -0
  415. parrot/tools/database/pg.py +343 -0
  416. parrot/tools/databasequery.py +1159 -0
  417. parrot/tools/db.py +1800 -0
  418. parrot/tools/ddgo.py +370 -0
  419. parrot/tools/decorators.py +271 -0
  420. parrot/tools/dftohtml.py +282 -0
  421. parrot/tools/document.py +549 -0
  422. parrot/tools/ecs.py +819 -0
  423. parrot/tools/edareport.py +368 -0
  424. parrot/tools/elasticsearch.py +1049 -0
  425. parrot/tools/employees.py +462 -0
  426. parrot/tools/epson/__init__.py +96 -0
  427. parrot/tools/excel.py +683 -0
  428. parrot/tools/file/__init__.py +13 -0
  429. parrot/tools/file/abstract.py +76 -0
  430. parrot/tools/file/gcs.py +378 -0
  431. parrot/tools/file/local.py +284 -0
  432. parrot/tools/file/s3.py +511 -0
  433. parrot/tools/file/tmp.py +309 -0
  434. parrot/tools/file/tool.py +501 -0
  435. parrot/tools/file_reader.py +129 -0
  436. parrot/tools/flowtask/__init__.py +19 -0
  437. parrot/tools/flowtask/tool.py +761 -0
  438. parrot/tools/gittoolkit.py +508 -0
  439. parrot/tools/google/__init__.py +18 -0
  440. parrot/tools/google/base.py +169 -0
  441. parrot/tools/google/tools.py +1251 -0
  442. parrot/tools/googlelocation.py +5 -0
  443. parrot/tools/googleroutes.py +5 -0
  444. parrot/tools/googlesearch.py +5 -0
  445. parrot/tools/googlesitesearch.py +5 -0
  446. parrot/tools/googlevoice.py +2 -0
  447. parrot/tools/gvoice.py +695 -0
  448. parrot/tools/ibisworld/README.md +225 -0
  449. parrot/tools/ibisworld/__init__.py +11 -0
  450. parrot/tools/ibisworld/tool.py +366 -0
  451. parrot/tools/jiratoolkit.py +1718 -0
  452. parrot/tools/manager.py +1098 -0
  453. parrot/tools/math.py +152 -0
  454. parrot/tools/metadata.py +476 -0
  455. parrot/tools/msteams.py +1621 -0
  456. parrot/tools/msword.py +635 -0
  457. parrot/tools/multidb.py +580 -0
  458. parrot/tools/multistoresearch.py +369 -0
  459. parrot/tools/networkninja.py +167 -0
  460. parrot/tools/nextstop/__init__.py +4 -0
  461. parrot/tools/nextstop/base.py +286 -0
  462. parrot/tools/nextstop/employee.py +733 -0
  463. parrot/tools/nextstop/store.py +462 -0
  464. parrot/tools/notification.py +435 -0
  465. parrot/tools/o365/__init__.py +42 -0
  466. parrot/tools/o365/base.py +295 -0
  467. parrot/tools/o365/bundle.py +522 -0
  468. parrot/tools/o365/events.py +554 -0
  469. parrot/tools/o365/mail.py +992 -0
  470. parrot/tools/o365/onedrive.py +497 -0
  471. parrot/tools/o365/sharepoint.py +641 -0
  472. parrot/tools/openapi_toolkit.py +904 -0
  473. parrot/tools/openweather.py +527 -0
  474. parrot/tools/pdfprint.py +1001 -0
  475. parrot/tools/powerbi.py +518 -0
  476. parrot/tools/powerpoint.py +1113 -0
  477. parrot/tools/pricestool.py +146 -0
  478. parrot/tools/products/__init__.py +246 -0
  479. parrot/tools/prophet_tool.py +171 -0
  480. parrot/tools/pythonpandas.py +630 -0
  481. parrot/tools/pythonrepl.py +910 -0
  482. parrot/tools/qsource.py +436 -0
  483. parrot/tools/querytoolkit.py +395 -0
  484. parrot/tools/quickeda.py +827 -0
  485. parrot/tools/resttool.py +553 -0
  486. parrot/tools/retail/__init__.py +0 -0
  487. parrot/tools/retail/bby.py +528 -0
  488. parrot/tools/sandboxtool.py +703 -0
  489. parrot/tools/sassie/__init__.py +352 -0
  490. parrot/tools/scraping/__init__.py +7 -0
  491. parrot/tools/scraping/docs/select.md +466 -0
  492. parrot/tools/scraping/documentation.md +1278 -0
  493. parrot/tools/scraping/driver.py +436 -0
  494. parrot/tools/scraping/models.py +576 -0
  495. parrot/tools/scraping/options.py +85 -0
  496. parrot/tools/scraping/orchestrator.py +517 -0
  497. parrot/tools/scraping/readme.md +740 -0
  498. parrot/tools/scraping/tool.py +3115 -0
  499. parrot/tools/seasonaldetection.py +642 -0
  500. parrot/tools/shell_tool/__init__.py +5 -0
  501. parrot/tools/shell_tool/actions.py +408 -0
  502. parrot/tools/shell_tool/engine.py +155 -0
  503. parrot/tools/shell_tool/models.py +322 -0
  504. parrot/tools/shell_tool/tool.py +442 -0
  505. parrot/tools/site_search.py +214 -0
  506. parrot/tools/textfile.py +418 -0
  507. parrot/tools/think.py +378 -0
  508. parrot/tools/toolkit.py +298 -0
  509. parrot/tools/webapp_tool.py +187 -0
  510. parrot/tools/whatif.py +1279 -0
  511. parrot/tools/workday/MULTI_WSDL_EXAMPLE.md +249 -0
  512. parrot/tools/workday/__init__.py +6 -0
  513. parrot/tools/workday/models.py +1389 -0
  514. parrot/tools/workday/tool.py +1293 -0
  515. parrot/tools/yfinance_tool.py +306 -0
  516. parrot/tools/zipcode.py +217 -0
  517. parrot/utils/__init__.py +2 -0
  518. parrot/utils/helpers.py +73 -0
  519. parrot/utils/parsers/__init__.py +5 -0
  520. parrot/utils/parsers/toml.c +12078 -0
  521. parrot/utils/parsers/toml.cpython-310-x86_64-linux-gnu.so +0 -0
  522. parrot/utils/parsers/toml.pyx +21 -0
  523. parrot/utils/toml.py +11 -0
  524. parrot/utils/types.cpp +20936 -0
  525. parrot/utils/types.cpython-310-x86_64-linux-gnu.so +0 -0
  526. parrot/utils/types.pyx +213 -0
  527. parrot/utils/uv.py +11 -0
  528. parrot/version.py +10 -0
  529. parrot/yaml-rs/Cargo.lock +350 -0
  530. parrot/yaml-rs/Cargo.toml +19 -0
  531. parrot/yaml-rs/pyproject.toml +19 -0
  532. parrot/yaml-rs/python/yaml_rs/__init__.py +81 -0
  533. parrot/yaml-rs/src/lib.rs +222 -0
  534. requirements/docker-compose.yml +24 -0
  535. requirements/requirements-dev.txt +21 -0
parrot/stores/postgres.py
@@ -0,0 +1,2853 @@
+from typing import Any, Dict, List, Optional, Union, Callable
+import uuid
+from contextlib import asynccontextmanager
+import numpy as np
+import sqlalchemy
+from sqlalchemy import (
+    text,
+    Column,
+    insert,
+    Table,
+    MetaData,
+    select,
+    asc,
+    func,
+    event,
+    JSON,
+    Index
+)
+from sqlalchemy.sql import literal_column
+from sqlalchemy import bindparam
+from sqlalchemy.orm import aliased
+from sqlalchemy.ext.asyncio import (
+    create_async_engine,
+    AsyncSession,
+    AsyncEngine,
+    async_sessionmaker
+)
+from sqlalchemy.sql.expression import cast
+from sqlalchemy.dialects.postgresql import JSONB, ARRAY
+from sqlalchemy.orm import (
+    declarative_base,
+    DeclarativeBase,
+    Mapped,
+    mapped_column
+)
+# PgVector
+from pgvector.sqlalchemy import Vector
+from pgvector.asyncpg import register_vector
+from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
+# Datamodel
+from datamodel.parsers.json import json_encoder  # pylint: disable=E0611
+from navconfig.logging import logging
+from .abstract import AbstractStore
+from ..conf import default_sqlalchemy_pg
+from .models import SearchResult, Document, DistanceStrategy
+from .utils.chunking import LateChunkingProcessor
+
+
+def vector_distance(embedding_column, vector, op):
+    return text(f"{embedding_column} {op} :query_embedding").label("distance")
+
+class Base(DeclarativeBase):
+    pass
+
+class PgVectorStore(AbstractStore):
+    """
+    A PostgreSQL vector store implementation using pgvector, completely independent of Langchain.
+    This store interacts directly with a specified schema and table for robust data isolation.
+    """
+    def __init__(
+        self,
+        table: str = None,
+        schema: str = 'public',
+        id_column: str = 'id',
+        embedding_column: str = 'embedding',
+        document_column: str = 'document',
+        text_column: str = 'text',
+        embedding_model: Union[dict, str] = "sentence-transformers/all-mpnet-base-v2",
+        embedding: Optional[Callable] = None,
+        distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
+        use_uuid: bool = False,
+        pool_size: int = 50,
+        auto_initialize: bool = True,
+        **kwargs
+    ):
+        """ Initializes the PgVectorStore with the specified parameters.
+        """
+        self.table_name = table
+        self.schema = schema
+        self._id_column: str = id_column
+        self._embedding_column: str = embedding_column
+        self._document_column: str = document_column
+        self._text_column: str = text_column
+        self.distance_strategy = distance_strategy
+        self._use_uuid: bool = use_uuid
+        self._embedding_store_cache: Dict[str, Any] = {}
+        self._max_size = pool_size or 50
+        self._auto_initialize_db: bool = auto_initialize
+        super().__init__(
+            embedding_model=embedding_model,
+            embedding=embedding,
+            **kwargs
+        )
+        self.dsn = kwargs.get('dsn', default_sqlalchemy_pg)
+        self._connection: Optional[AsyncEngine] = None
+        self._session_factory: Optional[async_sessionmaker] = None
+        self._session: Optional[AsyncSession] = None
+        self._context_depth: int = 0  # nesting counter used by __aenter__/__aexit__ below
+        self.logger = logging.getLogger("PgVectorStore")
+        self.embedding_store = None
+        if table:
+            # create a table definition:
+            self.embedding_store = self._define_collection_store(
+                table=table,
+                schema=schema,
+                dimension=self.dimension,
+                id_column=id_column,
+                embedding_column=embedding_column,
+                document_column=self._document_column,
+                text_column=text_column,
+            )
+
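For orientation before the method bodies below, a minimal construction sketch (illustrative, not part of the diff; the DSN and table name are placeholders, and it assumes `AbstractStore.__init__` resolves `embedding_model` into the `self._embed_` embedder and `self.dimension` used throughout this class):

```python
import asyncio
from parrot.stores.postgres import PgVectorStore
from parrot.stores.models import DistanceStrategy

async def main():
    store = PgVectorStore(
        table="kb_embeddings",                 # placeholder table name
        schema="public",
        distance_strategy=DistanceStrategy.COSINE,
        dsn="postgresql+asyncpg://user:pass@localhost/vectors",  # placeholder DSN
    )
    await store.connection()  # build engine, pool and session factory; optionally init pgvector
    try:
        ...  # add_documents() / similarity_search() calls, shown further below
    finally:
        await store.disconnect()

asyncio.run(main())
```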
+    def get_id_column(self, use_uuid: bool) -> sqlalchemy.Column:
+        """
+        Return the ID column definition based on whether to use UUID or not.
+        If use_uuid is True, the ID column will be a PostgreSQL UUID type with
+        server-side generation using uuid_generate_v4().
+        If use_uuid is False, the ID column will be a String type with a default
+        value generated by Python's uuid.uuid4() function.
+        """
+        if use_uuid:
+            # DB will auto-generate UUID; SQLAlchemy should not set a default!
+            return sqlalchemy.Column(
+                sqlalchemy.dialects.postgresql.UUID(as_uuid=True),
+                primary_key=True,
+                index=True,
+                unique=True,
+                server_default=sqlalchemy.text('uuid_generate_v4()')
+            )
+        else:
+            # Python generates UUID (as string)
+            return sqlalchemy.Column(
+                sqlalchemy.String,
+                primary_key=True,
+                index=True,
+                unique=True,
+                default=lambda: str(uuid.uuid4())
+            )
+
+    def _define_collection_store(
+        self,
+        table: str,
+        schema: str,
+        dimension: int = 384,
+        id_column: str = 'id',
+        embedding_column: str = 'embedding',
+        document_column: str = 'document',
+        metadata_column: str = 'cmetadata',
+        text_column: str = 'text',
+        store_name: str = 'EmbeddingStore',
+        colbert_dimension: int = 128  # ColBERT token dimension
+    ) -> Any:
+        """Dynamically define a SQLAlchemy Table for pgvector storage.
+
+        Args:
+            table: The name of the table to create.
+            schema: The schema in which to create the table.
+            dimension: The dimensionality of the vector embeddings.
+        """
+        fq_table_name = f"{schema}.{table}"
+        if fq_table_name in self._embedding_store_cache:
+            return self._embedding_store_cache[fq_table_name]
+
+        self.logger.notice(
+            f"Defining dynamic ORM class for table {fq_table_name} with dimension {dimension}"
+        )
+        table_args = {
+            "schema": schema,
+            "extend_existing": True
+        }
+        attrs = {
+            '__tablename__': table,
+            '__table_args__': table_args,
+            # id_column: self.get_id_column(use_uuid=self._use_uuid),
+            id_column: mapped_column(
+                sqlalchemy.String if not self._use_uuid else sqlalchemy.dialects.postgresql.UUID(as_uuid=True),
+                primary_key=True,
+                index=True,
+                unique=True,
+                default=lambda: str(uuid.uuid4()) if not self._use_uuid else None,
+                server_default=sqlalchemy.text('uuid_generate_v4()') if self._use_uuid else None
+            ),
+            embedding_column: mapped_column(Vector(dimension)),
+            text_column: mapped_column(sqlalchemy.String, nullable=True),
+            document_column: mapped_column(sqlalchemy.String, nullable=True),
+            metadata_column: mapped_column(JSONB, nullable=True),
+
+            # embedding_column: Column(Vector(dimension)),
+            # text_column: Column(sqlalchemy.String, nullable=True),
+            # document_column: Column(sqlalchemy.String, nullable=True),
+            # metadata_column: Column(JSONB, nullable=True),
+            # ColBERT columns
+            # 'token_embeddings': Column(ARRAY(Vector(colbert_dimension)), nullable=True),
+            # 'num_tokens': Column(sqlalchemy.Integer, nullable=True),
+            # 'collection_id': Column(
+            #     sqlalchemy.dialects.postgresql.UUID(as_uuid=True),
+            #     index=True,
+            #     unique=True,
+            #     default=uuid.uuid4,
+            #     server_default=sqlalchemy.text('uuid_generate_v4()')
+            # )
+            'token_embeddings': mapped_column(ARRAY(Vector(colbert_dimension)), nullable=True),
+            'num_tokens': mapped_column(sqlalchemy.Integer, nullable=True),
+            'collection_id': mapped_column(
+                sqlalchemy.dialects.postgresql.UUID(as_uuid=True),
+                index=True,
+                unique=True,
+                default=uuid.uuid4,
+                server_default=sqlalchemy.text('uuid_generate_v4()')
+            )
+        }
+
+        # Create dynamic ORM class
+        EmbeddingStore = type(store_name, (Base,), attrs)
+        EmbeddingStore.__name__ = store_name
+        EmbeddingStore.__qualname__ = store_name
+
+        # Cache the store
+        self._embedding_store_cache[fq_table_name] = EmbeddingStore
+        self.logger.debug(
+            f"Created dynamic ORM class {store_name} for table {fq_table_name}"
+        )
+
+        return EmbeddingStore
+
+    def define_collection_table(
+        self,
+        table: str,
+        schema: str,
+        dimension: int = 384,
+        metadata: Optional[MetaData] = None,
+        use_uuid: bool = False,
+        id_column: str = 'id',
+        embedding_column: str = 'embedding'
+    ) -> sqlalchemy.Table:
+        """Dynamically define a SQLAlchemy Table for pgvector storage."""
+        columns = []
+
+        if use_uuid:
+            columns.append(Column(
+                id_column,
+                sqlalchemy.dialects.postgresql.UUID(as_uuid=True),
+                primary_key=True,
+                server_default=sqlalchemy.text("uuid_generate_v4()")
+            ))
+        else:
+            columns.append(Column(
+                id_column,
+                sqlalchemy.String,
+                primary_key=True,
+                default=lambda: str(uuid.uuid4())
+            ))
+
+        columns.extend([
+            Column(embedding_column, Vector(dimension)),
+            Column('text', sqlalchemy.String, nullable=True),
+            Column('document', sqlalchemy.String, nullable=True),
+            Column('cmetadata', JSONB, nullable=True)
+        ])
+
+        return Table(
+            table,
+            metadata,
+            *columns,
+            schema=schema
+        )
+
+    async def connection(self, dsn: str = None) -> AsyncEngine:
+        """Establishes and returns an async database connection."""
+        if self._connection is not None:
+            return self._connection
+        if not dsn:
+            dsn = self.dsn or default_sqlalchemy_pg
+        try:
+            self._connection = create_async_engine(
+                dsn,
+                future=True,
+                pool_size=self._max_size,  # High concurrency support
+                max_overflow=100,  # Burst capacity
+                pool_pre_ping=True,  # Connection health checks
+                pool_recycle=3600,  # Prevent stale connections (1 hour)
+                pool_timeout=30,  # Wait max 30s for connection
+                connect_args={
+                    "server_settings": {
+                        "jit": "off",  # Disable JIT for vector queries
+                        "random_page_cost": "1.1",  # SSD optimization
+                        "effective_cache_size": "24GB",  # Memory configuration
+                        "work_mem": "256MB"
+                    }
+                }
+            )
+            # @event.listens_for(self._connection.sync_engine, "first_connect")
+            # def connect(dbapi_connection, connection_record):
+            #     dbapi_connection.run_async(register_vector)
+
+            # Create session factory
+            self._session_factory = async_sessionmaker(
+                bind=self._connection,
+                class_=AsyncSession,
+                expire_on_commit=False,
+                autoflush=False,  # Manual control over flushing
+                autocommit=False
+            )
+            if self._auto_initialize_db:
+                await self.initialize_database()
+            self._connected = True
+            self.logger.info(
+                "Successfully connected to PostgreSQL."
+            )
+        except Exception as e:
+            self.logger.error(
+                f"Failed to connect to PostgreSQL: {e}"
+            )
+            self._connected = False
+            raise
+        return self._connection  # expose the engine, matching the AsyncEngine return annotation
+
+    async def get_session(self) -> AsyncSession:
+        """Get a session from the pool. This is the main method for getting connections."""
+        if not self._connection:
+            await self.connection()
+
+        if not self._session_factory:
+            raise RuntimeError("Session factory not initialized")
+
+        return self._session_factory()
+
+    @asynccontextmanager
+    async def session(self):
+        """
+        Context manager for handling database sessions with proper cleanup.
+        This is the recommended way to handle database operations.
+
+        Usage:
+            async with store.session() as session:
+                result = await session.execute(stmt)
+                await session.commit()  # if needed
+        """
+        if not self._connection:
+            await self.connection()
+
+        session = await self.get_session()
+        try:
+            yield session
+            # Auto-commit if no exception occurred
+            if session.in_transaction():
+                await session.commit()
+        except Exception:
+            # Auto-rollback on exception
+            if session.in_transaction():
+                await session.rollback()
+            raise
+        finally:
+            # Always close session (returns connection to pool)
+            await session.close()
+
+    async def initialize_database(self):
+        """Initialize with PgVector 0.8.0+ optimizations"""
+        try:
+            async with self.session() as session:
+                await session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
+
+                # Enable iterative scanning (breakthrough feature)
+                await session.execute(text("SET hnsw.iterative_scan = 'relaxed_order'"))
+                await session.execute(text("SET hnsw.max_scan_tuples = 20000"))
+                await session.execute(text("SET hnsw.ef_search = 200"))
+                await session.execute(text("SET ivfflat.iterative_scan = 'on'"))
+                await session.execute(text("SET ivfflat.max_probes = 100"))
+
+                # Performance tuning
+                await session.execute(text("SET maintenance_work_mem = '2GB'"))
+                await session.execute(text("SET max_parallel_maintenance_workers = 8"))
+                await session.execute(text("SET enable_seqscan = off"))
+
+                # Create ColBERT MaxSim function
+                await self._create_maxsim_function(session)
+
+                await session.commit()
+        except Exception as e:
+            self.logger.warning(f"⚠️ Database auto-initialization failed: {e}")
+            # Don't raise - let the engine continue to work
+
+    async def _create_maxsim_function(self, session):
+        """Create the MaxSim function for ColBERT late interaction scoring."""
+        maxsim_function = text("""
+            CREATE OR REPLACE FUNCTION max_sim(document vector[], query vector[])
+            RETURNS double precision AS $$
+            DECLARE
+                query_vec vector;
+                doc_vec vector;
+                max_similarity double precision;
+                total_score double precision := 0;
+                similarity double precision;
+            BEGIN
+                -- For each query token, find the maximum similarity with any document token
+                FOR i IN 1..array_length(query, 1) LOOP
+                    query_vec := query[i];
+                    max_similarity := -1;
+
+                    -- Find max similarity with all document tokens
+                    FOR j IN 1..array_length(document, 1) LOOP
+                        doc_vec := document[j];
+                        similarity := 1 - (query_vec <=> doc_vec); -- Convert distance to similarity
+
+                        IF similarity > max_similarity THEN
+                            max_similarity := similarity;
+                        END IF;
+                    END LOOP;
+
+                    -- Add the maximum similarity for this query token
+                    total_score := total_score + max_similarity;
+                END LOOP;
+
+                RETURN total_score;
+            END;
+            $$ LANGUAGE plpgsql IMMUTABLE STRICT;
+        """)
+
+        await session.execute(maxsim_function)
+        self.logger.info("✅ Created ColBERT MaxSim function")
+
+
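Once `max_sim` exists, it can be called from SQL for late-interaction re-ranking over the `token_embeddings` column defined earlier. A hedged sketch (not part of the diff; the table name is a placeholder, `q_tokens` is a list of query-token embeddings, and binding a `vector[]` parameter is driver-dependent, hence the explicit CAST):

```python
from sqlalchemy import text

stmt = text("""
    SELECT id, document,
           max_sim(token_embeddings, CAST(:q AS vector[])) AS score
    FROM public.kb_embeddings
    WHERE token_embeddings IS NOT NULL
    ORDER BY score DESC
    LIMIT 10
""")
async with store.session() as session:
    rows = (await session.execute(stmt, {"q": q_tokens})).fetchall()
```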
+    # Async Context Manager - improved pattern
+    async def __aenter__(self):
+        """
+        Context manager entry. Ensures engine is initialized and manages session lifecycle.
+        """
+        if not self._connection:
+            await self.connection()
+
+        # Create a session for this context if we don't have one
+        if self._session is None:
+            self._session = await self.get_session()
+
+        self._context_depth += 1
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """
+        Context manager exit. Properly handles session cleanup.
+        """
+        self._context_depth -= 1
+
+        # Only close session when we exit the outermost context
+        if self._context_depth == 0 and self._session:
+            try:
+                if exc_type is not None:
+                    # Exception occurred, rollback
+                    if self._session.in_transaction():
+                        await self._session.rollback()
+                else:
+                    # No exception, commit if in transaction
+                    if self._session.in_transaction():
+                        await self._session.commit()
+            finally:
+                # Always close the session (returns connection to pool)
+                await self._session.close()
+                self._session = None
+
+    async def _free_resources(self):
+        """Clean up resources but keep the engine/pool available."""
+        if self._embed_:
+            self._embed_.free()
+            self._embed_ = None
+
+        # Close current session if exists
+        if self._session:
+            await self._session.close()
+            self._session = None
+
+    async def disconnect(self) -> None:
+        """
+        Completely dispose of the engine and close all connections.
+        Call this when you're completely done with the store.
+        """
+        # Close current session first
+        if self._session:
+            await self._session.close()
+            self._session = None
+
+        # Dispose of the engine (closes all pooled connections)
+        if self._connection:
+            await self._connection.dispose()
+            self._connection = None
+        self._connected = False
+        self._session_factory = None
+        self.logger.info(
+            "🔌 PostgreSQL engine disposed and all connections closed"
+        )
+
+    async def add_documents(
+        self,
+        documents: List[Document],
+        table: str = None,
+        schema: str = None,
+        embedding_column: str = 'embedding',
+        content_column: str = 'document',
+        metadata_column: str = 'cmetadata',
+        **kwargs
+    ) -> None:
+        """
+        Embeds and adds documents to the specified table.
+
+        Args:
+            documents: A list of Document objects to add.
+            table: The name of the table.
+            schema: The database schema where the table resides.
+            embedding_column: The name of the column to store embeddings.
+            content_column: The name of the column to store the main text content.
+            metadata_column: The name of the JSONB column for metadata.
+        """
+        if not self._connected:
+            await self.connection()
+
+        if not table:
+            table = self.table_name
+        if not schema:
+            schema = self.schema
+
+        texts = [doc.page_content for doc in documents]
+        embeddings = self._embed_.embed_documents(texts)
+        metadatas = [doc.metadata for doc in documents]
+
+        # Step 1: Ensure the ORM table is initialized
+        if self.embedding_store is None:
+            self.embedding_store = self._define_collection_store(
+                table=table,
+                schema=schema,
+                dimension=self.dimension,
+                id_column=self._id_column,
+                embedding_column=embedding_column,
+                document_column=content_column,
+                metadata_column=metadata_column,
+                text_column=self._text_column,
+            )
+
+        # Step 2: Prepare values for bulk insert
+        values = [
+            {
+                self._id_column: str(uuid.uuid4()),
+                embedding_column: embeddings[i].tolist() if isinstance(
+                    embeddings[i], np.ndarray
+                ) else embeddings[i],
+                content_column: texts[i],
+                metadata_column: metadatas[i] or {}
+            }
+            for i in range(len(documents))
+        ]
+
+        # Step 3: Build insert statement using SQLAlchemy's insert()
+        insert_stmt = insert(self.embedding_store)
+
+        # Step 4: Execute using async executemany
+        try:
+            async with self.session() as session:
+                await session.execute(insert_stmt, values)
+                self.logger.info(
+                    f"✅ Successfully added {len(documents)} documents to '{schema}.{table}'"
+                )
+        except Exception as e:
+            self.logger.error(f"Error adding documents: {e}")
+            raise
+
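A usage sketch for the insert path (illustrative; `Document` is imported above from `.models` and, as the code implies, carries `page_content` and `metadata` fields):

```python
from parrot.stores.models import Document

docs = [
    Document(page_content="pgvector stores embeddings in Postgres.",
             metadata={"source": "notes", "lang": "en"}),
    Document(page_content="HNSW trades memory for recall.",
             metadata={"source": "notes", "lang": "en"}),
]
# Embeds each page_content and bulk-inserts rows with fresh UUID ids
await store.add_documents(docs, table="kb_embeddings", schema="public")
```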
+    def get_distance_strategy(
+        self,
+        embedding_column_obj,
+        query_embedding,
+        metric: str = None
+    ) -> Any:
+        """
+        Return the appropriate distance expression based on the metric or configured strategy.
+
+        Args:
+            embedding_column_obj: The SQLAlchemy column object for embeddings
+            query_embedding: The query embedding vector
+            metric: Optional metric string ('COSINE', 'L2', 'IP', 'DOT')
+                - if None, uses self.distance_strategy
+        """
+        # Use provided metric or fall back to instance distance_strategy
+        strategy = metric or self.distance_strategy
+        # self.logger.debug(
+        #     f"PgVector: using distance strategy → {strategy}"
+        # )
+
+        # Convert string metrics to DistanceStrategy enum if needed
+        if isinstance(strategy, str):
+            metric_mapping = {
+                'COSINE': DistanceStrategy.COSINE,
+                'L2': DistanceStrategy.EUCLIDEAN_DISTANCE,
+                'EUCLIDEAN': DistanceStrategy.EUCLIDEAN_DISTANCE,
+                'IP': DistanceStrategy.MAX_INNER_PRODUCT,
+                'DOT': DistanceStrategy.DOT_PRODUCT,
+                'DOT_PRODUCT': DistanceStrategy.DOT_PRODUCT,
+                'MAX_INNER_PRODUCT': DistanceStrategy.MAX_INNER_PRODUCT
+            }
+            strategy = metric_mapping.get(strategy.upper(), DistanceStrategy.COSINE)
+
+        # self.logger.debug(
+        #     f"PgVector: using distance strategy → {strategy}"
+        # )
+
+        # Convert numpy array to list if needed
+        if isinstance(query_embedding, np.ndarray):
+            query_embedding = query_embedding.tolist()
+
+        if strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
+            return embedding_column_obj.l2_distance(query_embedding)
+        elif strategy == DistanceStrategy.COSINE:
+            return embedding_column_obj.cosine_distance(query_embedding)
+        elif strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+            return embedding_column_obj.max_inner_product(query_embedding)
+        elif strategy == DistanceStrategy.DOT_PRODUCT:
+            # Note: pgvector doesn't have dot_product, using max_inner_product
+            return embedding_column_obj.max_inner_product(query_embedding)
+        else:
+            raise ValueError(
+                f"Got unexpected value for distance: {strategy}. "
+                f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}."
+            )
+
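For reference, these column methods compile to pgvector's SQL operators: `<->` (L2 distance), `<=>` (cosine distance), and `<#>` (negative inner product). A small sketch under assumed names, mirroring how `similarity_search` uses it below:

```python
# embedding_col: the mapped Vector column object; query_embedding: list[float] or np.ndarray
expr = store.get_distance_strategy(
    embedding_col,
    query_embedding,
    metric="L2",   # resolves to embedding_col.l2_distance(...), i.e. the '<->' operator
)
```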
+    async def similarity_search(
+        self,
+        query: str,
+        table: str = None,
+        schema: str = None,
+        k: Optional[int] = None,
+        limit: int = None,
+        metadata_filters: Optional[Dict[str, Any]] = None,
+        score_threshold: Optional[float] = None,
+        metric: str = None,
+        embedding_column: str = 'embedding',
+        content_column: str = 'document',
+        metadata_column: str = 'cmetadata',
+        id_column: str = 'id',
+        additional_columns: Optional[List[str]] = None
+    ) -> List[SearchResult]:
+        """
+        Perform similarity search with optional threshold filtering.
+
+        Args:
+            query: The search query text
+            table: Table name (optional, uses default if not provided)
+            schema: Schema name (optional, uses default if not provided)
+            limit: Maximum number of results to return
+            score_threshold: Maximum distance threshold (
+                results with distance > threshold will be filtered out)
+            metadata_filters: Dictionary of metadata filters to apply
+            metric: Distance metric to use ('COSINE', 'L2', 'IP')
+            embedding_column: Name of the embedding column
+            content_column: Name of the content column
+            metadata_column: Name of the metadata column
+            id_column: Name of the ID column
+            additional_columns: List of additional columns to include in results.
+        Returns:
+            List of SearchResult objects with content, metadata, score, collection_id, and record_id
+        """
+        if not self._connected:
+            await self.connection()
+
+        table = table or self.table_name
+        schema = schema or self.schema
+
+        if k and not limit:
+            limit = k
+        if not limit:
+            limit = 10
+
+        # Step 1: Ensure the ORM class exists
+        if not self.embedding_store:
+            self.embedding_store = self._define_collection_store(
+                table=table,
+                schema=schema,
+                dimension=self.dimension,
+                id_column=self._id_column,
+                embedding_column=embedding_column,
+                document_column=content_column,
+                metadata_column=metadata_column,
+                text_column=self._text_column,
+            )
+
+        # Step 2: Embed the query
+        query_embedding = self._embed_.embed_query(query)
+
+        # Get the actual column objects
+        content_col = getattr(self.embedding_store, content_column)
+        metadata_col = getattr(self.embedding_store, metadata_column)
+        embedding_col = getattr(self.embedding_store, embedding_column)
+        id_col = getattr(self.embedding_store, id_column)
+        collection_id_col = getattr(self.embedding_store, 'collection_id')
+
+        # Get the distance expression using the appropriate method
+        distance_expr = self.get_distance_strategy(
+            embedding_col,
+            query_embedding,
+            metric=metric
+        ).label("distance")
+        # self.logger.debug(f"Compiled distance expr → {distance_expr}")
+
+
+        # Build the select columns list
+        select_columns = [
+            id_col,
+            content_col,
+            metadata_col,
+            distance_expr,
+            collection_id_col,
+        ]
+
+        # Add additional columns dynamically using literal_column (no validation)
+        if additional_columns:
+            for col_name in additional_columns:
+                # Use literal_column to reference any column name without ORM validation
+                additional_col = literal_column(f'"{col_name}"').label(col_name)
+                select_columns.append(additional_col)
+                self.logger.debug(f"Added dynamic column: {col_name}")
+
+        # Step 5: Construct statement
+        stmt = (
+            select(*select_columns)
+            .select_from(self.embedding_store)  # Explicitly specify the table
+            .order_by(asc(distance_expr))
+        )
+
+        # Apply threshold filter if provided
+        if score_threshold is not None:
+            stmt = stmt.where(distance_expr <= score_threshold)
+
+        if limit:
+            stmt = stmt.limit(limit)
+
+        # 6) Apply any JSONB metadata filters
+        if metadata_filters:
+            for key, val in metadata_filters.items():
+                if isinstance(val, bool):
+                    # Handle boolean values properly in JSONB
+                    stmt = stmt.where(
+                        metadata_col[key].astext.cast(sqlalchemy.Boolean) == val
+                    )
+                else:
+                    stmt = stmt.where(
+                        metadata_col[key].astext == str(val)
+                    )
+
+        try:
+            # Execute query
+            async with self.session() as session:
+                result = await session.execute(stmt)
+                rows = result.fetchall()
+                # Create enhanced SearchResult objects
+                results = []
+                for row in rows:
+                    metadata = row[2]
+                    metadata['collection_id'] = row[4]
+                    # Add additional columns as a dictionary (starting from index 5)
+                    if additional_columns:
+                        for i, col_name in enumerate(additional_columns):
+                            metadata[col_name] = row[5 + i]
+                    # Create an enhanced SearchResult with additional fields
+                    search_result = SearchResult(
+                        id=row[0],
+                        content=row[1],  # content_col
+                        metadata=metadata,  # metadata_col
+                        score=row[3]  # distance
+                    )
+                    results.append(search_result)
+
+                return results
+        except Exception as e:
+            self.logger.error(f"Error during similarity search: {e}")
+            raise
+
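A query-side sketch matching the signature above (illustrative values, reusing the `store` from the earlier example):

```python
results = await store.similarity_search(
    "how do I tune HNSW recall?",
    k=5,
    metric="COSINE",
    metadata_filters={"lang": "en"},  # matched against the cmetadata JSONB column
    score_threshold=0.35,             # keep only rows with distance <= 0.35
)
for r in results:
    print(r.score, r.content[:80], r.metadata.get("collection_id"))
```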
770
+ def get_vector(self, metric_type: str = None, **kwargs):
771
+ raise NotImplementedError("This method is part of the old implementation.")
772
+
773
+ async def drop_collection(self, table: str, schema: str = 'public') -> None:
774
+ """
775
+ Drops the specified table in the given schema.
776
+
777
+ Args:
778
+ table: The name of the table to drop.
779
+ schema: The database schema where the table resides.
780
+ """
781
+ if not self._connected:
782
+ await self.connection()
783
+
784
+ full_table_name = f"{schema}.{table}"
785
+ async with self._connection.begin() as conn:
786
+ await conn.execute(text(f"DROP TABLE IF EXISTS {full_table_name}"))
787
+ self.logger.info(f"Table '{full_table_name}' dropped successfully.")
788
+
789
+ async def prepare_embedding_table(
790
+ self,
791
+ table: str,
792
+ schema: str = 'public',
793
+ conn: AsyncEngine = None,
794
+ id_column: str = 'id',
795
+ embedding_column: str = 'embedding',
796
+ document_column: str = 'document',
797
+ metadata_column: str = 'cmetadata',
798
+ dimension: int = 768,
799
+ colbert_dimension: int = 128, # ColBERT token dimension
800
+ use_jsonb: bool = True,
801
+ drop_columns: bool = False,
802
+ create_all_indexes: bool = True,
803
+ **kwargs
804
+ ):
805
+ """
806
+ Prepare a Postgres Table as an embedding table in PostgreSQL with advanced features.
807
+ This method prepares a table with the following columns:
808
+ - id: unique identifier (String)
809
+ - embedding: the vector column (Vector(dimension) or JSONB)
810
+ - document: text column containing the document
811
+ - collection_id: UUID column for collection identification.
812
+ - metadata: JSONB column for metadata
813
+ - Additional columns based on the provided `columns` list
814
+ - Enhanced indexing strategies for efficient querying
815
+ - Support for multiple distance strategies (COSINE, L2, IP, etc.)
816
+ Args:
817
+ - tablename (str): Name of the table to create.
818
+ - embedding_column (str): Name of the column for storing embeddings.
819
+ - document_column (str): Name of the column for storing document text.
820
+ - metadata_column (str): Name of the column for storing metadata.
821
+ - dimension (int): Dimension of the embedding vector.
822
+ - id_column (str): Name of the column for storing unique identifiers.
823
+ - use_jsonb (bool): Whether to use JSONB for metadata storage.
824
+ - drop_columns (bool): Whether to drop existing columns.
825
+ - create_all_indexes (bool): Whether to create all distance strategies.
826
+ """
827
+ tablename = f"{schema}.{table}"
828
+ # Drop existing columns if requested
829
+ if drop_columns:
830
+ columns_to_drop = [
831
+ document_column, embedding_column, metadata_column,
832
+ 'token_embeddings', 'num_tokens'
833
+ ]
834
+ for column in columns_to_drop:
835
+ await conn.execute(
836
+ sqlalchemy.text(
837
+ f'ALTER TABLE {tablename} DROP COLUMN IF EXISTS {column};'
838
+ )
839
+ )
840
+ # Create metadata column as a jsonb field
841
+ if use_jsonb:
842
+ await conn.execute(
843
+ sqlalchemy.text(
844
+ f'ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS {metadata_column} JSONB;'
845
+ )
846
+ )
847
+ # Use pgvector type for dense embeddings
848
+ await conn.execute(
849
+ sqlalchemy.text(
850
+ f'ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS {embedding_column} vector({dimension});'
851
+ )
852
+ )
853
+ # Add ColBERT columns for token-level embeddings
854
+ await conn.execute(
855
+ sqlalchemy.text(
856
+ f'ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS token_embeddings vector({colbert_dimension})[];'
857
+ )
858
+ )
859
+ await conn.execute(
860
+ sqlalchemy.text(
861
+ f'ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS num_tokens INTEGER;'
862
+ )
863
+ )
864
+ # Create the additional columns
865
+ for col_name, col_type in [
866
+ (document_column, 'TEXT'),
867
+ (id_column, 'varchar'),
868
+ ]:
869
+ await conn.execute(
870
+ sqlalchemy.text(
871
+ f'ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS {col_name} {col_type};'
872
+ )
873
+ )
874
+ # Create the Collection Column:
875
+ await conn.execute(
876
+ sqlalchemy.text(
877
+ f"ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS collection_id UUID;"
878
+ )
879
+ )
880
+ await conn.execute(
881
+ sqlalchemy.text(
882
+ f"ALTER TABLE {tablename} ADD COLUMN IF NOT EXISTS collection_id UUID DEFAULT uuid_generate_v4();"
883
+ )
884
+ )
885
+ # Set the value on null values before declaring not null:
886
+ await conn.execute(
887
+ sqlalchemy.text(
888
+ f"UPDATE {tablename} SET collection_id = uuid_generate_v4() WHERE collection_id IS NULL;"
889
+ )
890
+ )
891
+ await conn.execute(
892
+ sqlalchemy.text(
893
+ f"ALTER TABLE {tablename} ALTER COLUMN collection_id SET NOT NULL;"
894
+ )
895
+ )
896
+ await conn.execute(
897
+ sqlalchemy.text(
898
+ f"CREATE UNIQUE INDEX IF NOT EXISTS idx_{table}_{schema}_collection_id ON {tablename} (collection_id);"
899
+ )
900
+ )
901
+ # ✅ CREATE COMPREHENSIVE INDEXES
902
+ if create_all_indexes:
903
+ await self._create_all_indexes(conn, tablename, embedding_column)
904
+ else:
905
+ # Create index only for current strategy
906
+ distance_strategy_ops = {
907
+ DistanceStrategy.COSINE: "vector_cosine_ops",
908
+ DistanceStrategy.EUCLIDEAN_DISTANCE: "vector_l2_ops",
909
+ DistanceStrategy.MAX_INNER_PRODUCT: "vector_ip_ops",
910
+ DistanceStrategy.DOT_PRODUCT: "vector_ip_ops"
911
+ }
912
+
913
+ ops = distance_strategy_ops.get(self.distance_strategy, "vector_cosine_ops")
914
+ strategy_name = str(self.distance_strategy).rsplit('.', maxsplit=1)[-1].lower()
915
+
916
+ await conn.execute(
917
+ sqlalchemy.text(
918
+ f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_{strategy_name} "
919
+ f"ON {tablename} USING ivfflat ({embedding_column} {ops});"
920
+ )
921
+ )
922
+ print(f"✅ Created {strategy_name.upper()} index")
923
+
924
+ # Create ColBERT-specific indexes
925
+ await self._create_colbert_indexes(conn, tablename)
926
+
927
+ # Create JSONB indexes for better performance
928
+ await self._create_jsonb_indexes(
929
+ conn,
930
+ tablename,
931
+ metadata_column,
932
+ id_column
933
+ )
934
+ # Ensure the table is ready for embedding operations
935
+ self.embedding_store = self._define_collection_store(
936
+ table=table,
937
+ schema=schema,
938
+ dimension=dimension,
939
+ id_column=id_column,
940
+ embedding_column=embedding_column,
941
+ document_column=self._document_column
942
+ )
943
+ return True
944
+
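The preparation step above is driven entirely by keyword arguments. A minimal usage sketch, assuming `store` is a connected instance of this vector store class and the table/column names are illustrative:

```python
# Hypothetical call; `store` is assumed to be a connected instance of
# this vector store class (names below are illustrative).
await store.prepare_embedding_table(
    table="products",
    schema="public",
    embedding_column="embedding",
    document_column="document",
    metadata_column="cmetadata",
    dimension=768,
    id_column="id",
    use_jsonb=True,           # add a JSONB metadata column
    drop_columns=False,       # keep existing embedding columns
    create_all_indexes=True,  # build indexes for every distance strategy
)
```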
945
+ async def _create_all_indexes(self, conn, tablename: str, embedding_column: str):
946
+ """Create all standard vector indexes."""
947
+ print("🔧 Creating indexes for all distance strategies...")
948
+
949
+ # COSINE index (most common for text embeddings)
950
+ await conn.execute(
951
+ sqlalchemy.text(
952
+ f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_cosine "
953
+ f"ON {tablename} USING ivfflat ({embedding_column} vector_cosine_ops);"
954
+ )
955
+ )
956
+ print("✅ Created COSINE index")
957
+
958
+ # L2/Euclidean index
959
+ await conn.execute(
960
+ sqlalchemy.text(
961
+ f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_l2 "
962
+ f"ON {tablename} USING ivfflat ({embedding_column} vector_l2_ops);"
963
+ )
964
+ )
965
+ print("✅ Created L2 index")
966
+
967
+ # Inner Product index
968
+ try:
969
+ await conn.execute(
970
+ sqlalchemy.text(
971
+ f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_ip "
972
+ f"ON {tablename} USING ivfflat ({embedding_column} vector_ip_ops);"
973
+ )
974
+ )
975
+ print("✅ Created Inner Product index")
976
+ except Exception as e:
977
+ print(f"⚠️ Inner Product index creation failed: {e}")
978
+
979
+ # HNSW indexes for better performance (requires more memory)
980
+ try:
981
+ await conn.execute(
982
+ sqlalchemy.text(
983
+ f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_hnsw_cosine "
984
+ f"ON {tablename} USING hnsw ({embedding_column} vector_cosine_ops);"
985
+ )
986
+ )
987
+ print("✅ Created HNSW COSINE index")
988
+
989
+ await conn.execute(
990
+ sqlalchemy.text(
991
+ f"""CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_hnsw_l2
992
+ ON {tablename} USING hnsw ({embedding_column} vector_l2_ops) WITH (
993
+ m = 16, -- graph connectivity (higher → better recall, more memory)
994
+ ef_construction = 200 -- controls indexing time vs. recall
995
+ );"""
996
+ )
997
+ )
998
+ print("✅ Created HNSW EUCLIDEAN index")
999
+ except Exception as e:
1000
+ print(f"⚠️ HNSW index creation failed (this is optional): {e}")
1001
+
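Because every statement uses `CREATE INDEX IF NOT EXISTS`, reruns are idempotent. To confirm which indexes actually exist (HNSW creation may have been skipped), PostgreSQL's standard `pg_indexes` view can be queried; a sketch assuming `conn` is an async SQLAlchemy connection like the one passed to this method:

```python
import sqlalchemy

# List indexes on the embedding table; `conn` is assumed to be an
# async SQLAlchemy connection, run inside an async function.
result = await conn.execute(
    sqlalchemy.text(
        "SELECT indexname, indexdef FROM pg_indexes "
        "WHERE schemaname = :schema AND tablename = :table;"
    ),
    {"schema": "public", "table": "products"},
)
for name, definition in result.fetchall():
    print(name, "->", definition)
```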
1002
+ async def _create_colbert_indexes(self, conn, tablename: str):
1003
+ """Create ColBERT-specific indexes for token embeddings."""
1004
+ print("🔧 Creating ColBERT indexes...")
1005
+
1006
+ try:
1007
+ # GIN index for array operations on token embeddings
1008
+ await conn.execute(
1009
+ sqlalchemy.text(
1010
+ f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_token_embeddings_gin "
1011
+ f"ON {tablename} USING gin(token_embeddings);"
1012
+ )
1013
+ )
1014
+ print("✅ Created GIN index for token embeddings")
1015
+
1016
+ # Index on num_tokens for filtering
1017
+ await conn.execute(
1018
+ sqlalchemy.text(
1019
+ f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_num_tokens "
1020
+ f"ON {tablename} (num_tokens);"
1021
+ )
1022
+ )
1023
+ print("✅ Created index for num_tokens")
1024
+
1025
+ # Partial index for non-null token embeddings
1026
+ await conn.execute(
1027
+ sqlalchemy.text(
1028
+ f"CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_has_tokens "
1029
+ f"ON {tablename} (id) WHERE token_embeddings IS NOT NULL;"
1030
+ )
1031
+ )
1032
+ print("✅ Created partial index for documents with token embeddings")
1033
+
1034
+ except Exception as e:
1035
+ print(f"⚠️ ColBERT index creation failed: {e}")
1036
+
1037
+ async def create_embedding_table(
1038
+ self,
1039
+ table: str,
1040
+ columns: List[str],
1041
+ schema: str = 'public',
1042
+ embedding_column: str = 'embedding',
1043
+ document_column: str = 'document',
1044
+ metadata_column: str = 'cmetadata',
1045
+ dimension: int = None,
1046
+ id_column: str = 'id',
1047
+ use_jsonb: bool = False,
1048
+ drop_columns: bool = True,
1049
+ create_all_indexes: bool = True,
1050
+ **kwargs
1051
+ ):
1052
+ """
1053
+ Create an embedding table in PostgreSQL with advanced features.
1054
+ This method creates a table with the following columns:
1055
+ - id: unique identifier (String)
1056
+ - embedding: the vector column (Vector(dimension) or JSONB)
1057
+ - document: text column containing the document
1058
+ - cmetadata: JSONB column for metadata
1059
+ - Additional columns based on the provided `columns` list
1060
+ - Enhanced indexing strategies for efficient querying
1061
+ - Support for multiple distance strategies (COSINE, L2, IP, etc.)
1062
+ Args:
1063
+ - table (str): Name of the table to create.
1064
+ - columns (List[str]): List of column names to include in the table.
1065
+ - schema (str): Database schema where the table will be created.
1066
+ - embedding_column (str): Name of the column for storing embeddings.
1067
+ - document_column (str): Name of the column for storing document text.
1068
+ - metadata_column (str): Name of the column for storing metadata.
1069
+ - dimension (int): Dimension of the embedding vector.
1070
+ - id_column (str): Name of the column for storing unique identifiers.
1071
+ - use_jsonb (bool): Whether to use JSONB for metadata storage.
1072
+ - drop_columns (bool): Whether to drop existing columns.
1073
+ - create_all_indexes (bool): Whether to create indexes for all distance strategies.
+ 
+ The table is populated using a JSONB metadata strategy for better
+ semantic search, building multiple representations of each record:
+ 1. A natural-language searchable text (emphasizing the record ID)
+ 2. A structured key/value search text
+ 3. JSONB metadata with case variants of key fields for filtering
+ 4. A dense embedding generated from the structured search text
+ """
1083
+ tablename = f'{schema}.{table}'
1084
+ cols = ', '.join(columns)
1085
+ _qry = f'SELECT {cols} FROM {tablename};'
1086
+ dimension = dimension or self.dimension
1087
+
1088
+ # Generate a sample embedding to determine its dimension
1089
+ sample_vector = self._embed_.embedding.embed_query("sample text")
1090
+ vector_dim = len(sample_vector)
1091
+ self.logger.notice(
1092
+ f"USING EMBED {self._embed_} with dimension {vector_dim}"
1093
+ )
1094
+
1095
+ if vector_dim != dimension:
1096
+ raise ValueError(
1097
+ f"Expected embedding dimension {dimension}, but got {vector_dim}"
1098
+ )
1099
+
1100
+ async with self._connection.begin() as conn:
1101
+ result = await conn.execute(sqlalchemy.text(_qry))
1102
+ rows = result.fetchall()
1103
+
1104
+ await self.prepare_embedding_table(
1105
+ table=table,
1106
+ schema=schema,
1107
+ embedding_column=embedding_column,
1108
+ document_column=document_column,
1109
+ metadata_column=metadata_column,
1110
+ dimension=dimension,
1111
+ id_column=id_column,
1112
+ use_jsonb=use_jsonb,
1113
+ drop_columns=drop_columns,
1114
+ create_all_indexes=create_all_indexes,
1115
+ **kwargs
1116
+ )
1117
+
1118
+ # Populate the embedding data
1119
+ for i, row in enumerate(rows):
1120
+ _id = getattr(row, id_column)
1121
+ metadata = {col: getattr(row, col) for col in columns}
1122
+ data = await self._create_metadata_structure(metadata, id_column, _id)
1123
+
1124
+ # Generate embedding
1125
+ searchable_text = data['structured_search']
1126
+ print(f"🔍 Row {i + 1}/{len(rows)} - {_id}")
1127
+ print(f" Text: {searchable_text[:100]}...")
1128
+
1129
+ vector = self._embed_.embedding.embed_query(searchable_text)
1130
+ vector_str = "[" + ",".join(str(v) for v in vector) + "]"
1131
+
1132
+ await conn.execute(
1133
+ sqlalchemy.text(f"""
1134
+ UPDATE {tablename}
1135
+ SET {embedding_column} = :embeddings,
1136
+ {document_column} = :document,
1137
+ {metadata_column} = :metadata
1138
+ WHERE {id_column} = :id
1139
+ """),
1140
+ {
1141
+ "embeddings": vector_str,
1142
+ "document": searchable_text,
1143
+ "metadata": json_encoder(data),
1144
+ "id": _id
1145
+ }
1146
+ )
1147
+
1148
+ print("✅ Updated Table embeddings with comprehensive indexes.")
1149
+
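End to end, the method reads the listed columns, prepares the table, and back-fills one embedding per row. A hedged usage sketch (the `store` instance and all names are assumptions):

```python
# Hypothetical usage; `store` is assumed to be a connected instance of
# this vector store with an embedding model attached.
await store.create_embedding_table(
    table="stores",
    columns=["store_id", "store_name", "city", "state"],
    schema="public",
    dimension=768,
    id_column="store_id",
    use_jsonb=True,
)
```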
1150
+ def _create_natural_searchable_text(
1151
+ self,
1152
+ metadata: dict,
1153
+ id_column: str,
1154
+ record_id: str
1155
+ ) -> str:
1156
+ """
1157
+ Create well-structured, natural language text with proper separation.
1158
+
1159
+ This creates clean, readable text that embedding models can understand better.
1160
+ """
1161
+ # Start with the ID in multiple formats for exact matching
1162
+ text_parts = [
1163
+ f"ID: {record_id}",
1164
+ f"Identifier: {record_id}",
1165
+ id_column + ": " + record_id
1166
+ ]
1167
+
1168
+ # Process each field to create natural language descriptions
1169
+ for key, value in metadata.items():
1170
+ if value is None or value == '':
1171
+ continue
1172
+ clean_value = value.strip() if isinstance(value, str) else str(value)
1173
+ text_parts.append(f"{key}: {clean_value}")
1174
+ # Add the field in natural language format
1175
+ clean_key = key.replace('_', ' ').title()
1176
+ text_parts.append(f"{clean_key}={clean_value}")
1177
+
1178
+ # Join with commas and end with a period
1179
+ searchable_text = ', '.join(text_parts) + '.'
1180
+
1181
+ return searchable_text
1182
+
1183
+ def _create_structured_search_text(self, metadata: dict, id_column: str, record_id: str) -> str:
1184
+ """
1185
+ Create a more structured but still readable search text.
1186
+
1187
+ This emphasizes key-value relationships while staying readable.
1188
+ """
1189
+ # ID section with emphasis
1190
+ kv_sections = [
1191
+ f"ID: {record_id}",
1192
+ f"Identifier: {record_id}",
1193
+ id_column + ": " + record_id
1194
+ ]
1195
+
1196
+ # Key-value sections with clean separation
1197
+ for key, value in metadata.items():
1198
+ if value is None or value == '':
1199
+ continue
1200
+
1201
+ # Clean key-value representation
1202
+ clean_key = key.replace('_', ' ').title()
1203
+ kv_sections.append(f"{clean_key}: {value}")
1204
+ kv_sections.append(f"{key}: {value}")
1205
+
1206
+ # Combine with proper separation
1207
+ return ' | '.join(kv_sections)
1208
+
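Tracing the two helpers above on a small record shows the difference in layout (values are illustrative):

```python
metadata = {"store_name": "Acme Hardware", "city": "Austin"}

# _create_structured_search_text(metadata, "store_id", "S-001") yields:
#   "ID: S-001 | Identifier: S-001 | store_id: S-001 | "
#   "Store Name: Acme Hardware | store_name: Acme Hardware | "
#   "City: Austin | city: Austin"
#
# _create_natural_searchable_text(...) emits the same key/value pairs,
# joined by commas and terminated with a period instead of ' | '.
```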
1209
+ async def _create_metadata_structure(
1210
+ self,
1211
+ metadata: dict,
1212
+ id_column: str,
1213
+ _id: str
1214
+ ):
1215
+ """Create a structured metadata representation for the document."""
1216
+ # Create a structured metadata representation
1217
+ enhanced_metadata = {
1218
+ "id": _id,
1219
+ id_column: _id,
1220
+ "_variants": [
1221
+ _id,
1222
+ _id.lower(),
1223
+ _id.upper()
1224
+ ]
1225
+ }
1226
+ for key, value in metadata.items():
1227
+ enhanced_metadata[key] = value
1228
+ # Create searchable variants for key fields
1229
+ if value and isinstance(value, str):
1230
+ variants = [value, value.lower(), value.upper()]
1231
+ # Add variants without special characters
1232
+ clean_value = ''.join(c for c in str(value) if c.isalnum() or c.isspace())
1233
+ if clean_value != value:
1234
+ variants.append(clean_value)
1235
+ enhanced_metadata[f"_{key}_variants"] = list(set(variants))
1236
+ # create a full-text search field of searchable content
1237
+ enhanced_metadata['searchable_content'] = self._create_natural_searchable_text(
1238
+ metadata, id_column, _id
1239
+ )
1240
+
1241
+ # Also create a structured search text that emphasizes important fields
1242
+ enhanced_metadata['structured_search'] = self._create_structured_search_text(
1243
+ metadata, id_column, _id
1244
+ )
1245
+
1246
+ return enhanced_metadata
1247
+
1248
+ async def _create_jsonb_indexes(
1249
+ self,
1250
+ conn,
1251
+ tablename: str,
1252
+ metadata_col: str,
1253
+ id_column: str
1254
+ ):
1255
+ """Create optimized JSONB indexes for better search performance."""
1256
+
1257
+ print("🔧 Creating JSONB indexes on Metadata for optimized search...")
1258
+
1259
+ # Index for ID searches
1260
+ await conn.execute(
1261
+ sqlalchemy.text(f"""
1262
+ CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_{id_column}
1263
+ ON {tablename} USING BTREE (({metadata_col}->>'{id_column}'));
1264
+ """)
1265
+ )
1266
+ await conn.execute(
1267
+ sqlalchemy.text(f"""
1268
+ CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_id
1269
+ ON {tablename} USING BTREE (({metadata_col}->>'id'));
1270
+ """)
1271
+ )
1272
+
1273
+ # GIN index for full-text search on searchable content
1274
+ await conn.execute(
1275
+ sqlalchemy.text(f"""
1276
+ CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_fulltext
1277
+ ON {tablename} USING GIN (to_tsvector('english', {metadata_col}->>'searchable_content'));
1278
+ """)
1279
+ )
1280
+
1281
+ # GIN index for JSONB structure searches
1282
+ await conn.execute(
1283
+ sqlalchemy.text(f"""
1284
+ CREATE INDEX IF NOT EXISTS idx_{tablename.replace('.', '_')}_metadata_gin
1285
+ ON {tablename} USING GIN ({metadata_col});
1286
+ """)
1287
+ )
1288
+ print("✅ Created optimized JSONB indexes")
1289
+
1290
+ async def add_colbert_document(
1291
+ self,
1292
+ document_id: str,
1293
+ content: str,
1294
+ token_embeddings: np.ndarray,
1295
+ table: str,
1296
+ schema: str = 'public',
1297
+ metadata: Optional[Dict[str, Any]] = None,
1298
+ document_column: str = 'document',
1299
+ metadata_column: str = 'cmetadata',
1300
+ id_column: str = 'id',
1301
+ **kwargs
1302
+ ) -> None:
1303
+ """
1304
+ Add a document with ColBERT token embeddings to the specified table.
1305
+
1306
+ Args:
1307
+ document_id: Unique identifier for the document
1308
+ content: The document text content
1309
+ token_embeddings: NumPy array of token embeddings (shape: [num_tokens, embedding_dim])
1310
+ table: The name of the table
1311
+ schema: The database schema where the table resides
1312
+ metadata: Optional metadata dictionary
1313
+ document_column: Name of the document content column
1314
+ metadata_column: Name of the metadata column
1315
+ id_column: Name of the ID column
1316
+ """
1317
+ if not self._connected:
1318
+ await self.connection()
1319
+
1320
+ # Ensure the ORM table is initialized
1321
+ if self.embedding_store is None:
1322
+ self.embedding_store = self._define_collection_store(
1323
+ table=table,
1324
+ schema=schema,
1325
+ dimension=self.dimension,
1326
+ id_column=id_column,
1327
+ document_column=document_column,
1328
+ metadata_column=metadata_column,
1329
+ text_column=self._text_column,
1330
+ )
1331
+
1332
+ # Convert numpy array to list format for PostgreSQL
1333
+ if isinstance(token_embeddings, np.ndarray):
1334
+ token_embeddings_list = token_embeddings.tolist()
1335
+ else:
1336
+ token_embeddings_list = token_embeddings
1337
+
1338
+ num_tokens = len(token_embeddings_list)
1339
+
1340
+ # Prepare the insert/upsert data
1341
+ values = {
1342
+ id_column: document_id,
1343
+ document_column: content,
1344
+ 'token_embeddings': token_embeddings_list,
1345
+ 'num_tokens': num_tokens,
1346
+ metadata_column: metadata or {}
1347
+ }
1348
+
1349
+ # Build insert statement with upsert capability
1350
+ insert_stmt = insert(self.embedding_store).values(values)
1351
+
1352
+ # Create upsert statement (ON CONFLICT DO UPDATE)
1353
+ upsert_stmt = insert_stmt.on_conflict_do_update(
1354
+ index_elements=[id_column],
1355
+ set_={
1356
+ document_column: getattr(insert_stmt.excluded, document_column),
1358
+ 'token_embeddings': insert_stmt.excluded.token_embeddings,
1359
+ 'num_tokens': insert_stmt.excluded.num_tokens,
1360
+ metadata_column: getattr(insert_stmt.excluded, metadata_column),
1361
+ }
1362
+ )
1363
+
1364
+ try:
1365
+ async with self._connection.begin() as conn:
1366
+ await conn.execute(upsert_stmt)
1367
+
1368
+ self.logger.info(
1369
+ f"Successfully added ColBERT document '{document_id}' with {num_tokens} tokens to '{schema}.{table}'"
1370
+ )
1371
+ except Exception as e:
1372
+ self.logger.error(f"Error adding ColBERT document: {e}")
1373
+ raise
1374
+
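A sketch of storing one ColBERT-encoded document; the 24×128 token matrix is illustrative and would normally come from a ColBERT encoder:

```python
import numpy as np

# `store` is assumed to be a connected instance of this class.
token_matrix = np.random.rand(24, 128).astype(np.float32)  # [num_tokens, dim]
await store.add_colbert_document(
    document_id="doc-0001",
    content="ColBERT keeps one embedding per token.",
    token_embeddings=token_matrix,
    table="documents",
    schema="public",
    metadata={"source_type": "papers"},
)
```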
1375
+ async def colbert_search(
1376
+ self,
1377
+ query_tokens: np.ndarray,
1378
+ table: str,
1379
+ schema: str = 'public',
1380
+ top_k: int = 10,
1381
+ metadata_filters: Optional[Dict[str, Any]] = None,
1382
+ min_tokens: Optional[int] = None,
1383
+ max_tokens: Optional[int] = None,
1384
+ id_column: str = 'id',
1385
+ document_column: str = 'document',
1386
+ metadata_column: str = 'cmetadata',
1387
+ additional_columns: Optional[List[str]] = None
1388
+ ) -> List[SearchResult]:
1389
+ """
1390
+ Perform ColBERT search with late interaction using MaxSim scoring.
1391
+
1392
+ Args:
1393
+ query_tokens: NumPy array of query token embeddings (shape: [num_query_tokens, embedding_dim])
1394
+ table: Table name
1395
+ schema: Schema name
1396
+ top_k: Number of results to return
1397
+ metadata_filters: Optional metadata filters
1398
+ min_tokens: Minimum number of tokens in documents to consider
1399
+ max_tokens: Maximum number of tokens in documents to consider
1400
+ id_column: Name of the ID column
1401
+ document_column: Name of the document content column
1402
+ metadata_column: Name of the metadata column
1403
+ additional_columns: Additional columns to include in results
1404
+
1405
+ Returns:
1406
+ List of SearchResult objects ordered by ColBERT score (descending)
1407
+ """
1408
+ if not self._connected:
1409
+ await self.connection()
1410
+
1411
+ # Ensure the ORM table is initialized
1412
+ if self.embedding_store is None:
1413
+ self.embedding_store = self._define_collection_store(
1414
+ table=table,
1415
+ schema=schema,
1416
+ dimension=self.dimension,
1417
+ id_column=id_column,
1418
+ document_column=document_column,
1419
+ metadata_column=metadata_column,
1420
+ text_column=self._text_column,
1421
+ )
1422
+
1423
+ # Convert query tokens to list format
1424
+ if isinstance(query_tokens, np.ndarray):
1425
+ query_tokens_list = query_tokens.tolist()
1426
+ else:
1427
+ query_tokens_list = query_tokens
1428
+
1429
+ # Get column objects
1430
+ id_col = getattr(self.embedding_store, id_column)
1431
+ content_col = getattr(self.embedding_store, document_column)
1432
+ metadata_col = getattr(self.embedding_store, metadata_column)
1433
+ token_embeddings_col = getattr(self.embedding_store, 'token_embeddings')
1434
+ num_tokens_col = getattr(self.embedding_store, 'num_tokens')
1435
+ collection_id_col = getattr(self.embedding_store, 'collection_id')
1436
+
1437
+ # Build select columns
1438
+ select_columns = [
1439
+ id_col,
1440
+ content_col,
1441
+ metadata_col,
1442
+ collection_id_col,
1443
+ func.max_sim(token_embeddings_col, query_tokens_list).label('colbert_score')
1444
+ ]
1445
+
1446
+ # Add additional columns dynamically
1447
+ if additional_columns:
1448
+ for col_name in additional_columns:
1449
+ additional_col = literal_column(f'"{col_name}"').label(col_name)
1450
+ select_columns.append(additional_col)
1451
+
1452
+ # Build the query
1453
+ stmt = (
1454
+ select(*select_columns)
1455
+ .select_from(self.embedding_store)
1456
+ .where(token_embeddings_col.isnot(None)) # Only documents with token embeddings
1457
+ .order_by(func.max_sim(token_embeddings_col, query_tokens_list).desc())
1458
+ .limit(top_k)
1459
+ )
1460
+
1461
+ # Apply token count filters
1462
+ if min_tokens is not None:
1463
+ stmt = stmt.where(num_tokens_col >= min_tokens)
1464
+ if max_tokens is not None:
1465
+ stmt = stmt.where(num_tokens_col <= max_tokens)
1466
+
1467
+ # Apply metadata filters
1468
+ if metadata_filters:
1469
+ for key, value in metadata_filters.items():
1470
+ stmt = stmt.where(metadata_col[key].astext == str(value))
1471
+
1472
+ try:
1473
+ async with self._connection.connect() as conn:
1474
+ result = await conn.execute(stmt)
1475
+ rows = result.fetchall()
1476
+
1477
+ # Create SearchResult objects
1478
+ results = []
1479
+ for row in rows:
1480
+ # Enhance metadata with additional info
1481
+ metadata = dict(row[2]) if row[2] else {}
1482
+ metadata['collection_id'] = row[3]
1483
+ metadata['colbert_score'] = float(row[4])
1484
+
1485
+ # Add additional columns to metadata
1486
+ if additional_columns:
1487
+ for i, col_name in enumerate(additional_columns):
1488
+ metadata[col_name] = row[5 + i]
1489
+
1490
+ search_result = SearchResult(
1491
+ id=row[0],
1492
+ content=row[1],
1493
+ metadata=metadata,
1494
+ score=float(row[4]) # ColBERT score
1495
+ )
1496
+ results.append(search_result)
1497
+
1498
+ self.logger.info(
1499
+ f"ColBERT search returned {len(results)} results from {schema}.{table}"
1500
+ )
1501
+ return results
1502
+
1503
+ except Exception as e:
1504
+ self.logger.error(f"Error during ColBERT search: {e}")
1505
+ raise
1506
+
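Querying follows the same shape: encode the query into token embeddings and let the server-side `max_sim` function (which this method requires to exist in the database) rank documents. The query matrix below is illustrative:

```python
import numpy as np

# Hypothetical query-token matrix; in practice it comes from the same
# encoder used when the documents were stored.
query_matrix = np.random.rand(8, 128).astype(np.float32)
results = await store.colbert_search(
    query_tokens=query_matrix,
    table="documents",
    top_k=5,
    min_tokens=4,  # skip near-empty documents
    metadata_filters={"source_type": "papers"},
)
for r in results:
    print(r.id, r.metadata["colbert_score"])
```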
1507
+ async def hybrid_search(
1508
+ self,
1509
+ query: str,
1510
+ query_tokens: Optional[np.ndarray] = None,
1511
+ table: str = None,
1512
+ schema: str = None,
1513
+ top_k: int = 10,
1514
+ dense_weight: float = 0.7,
1515
+ colbert_weight: float = 0.3,
1516
+ metadata_filters: Optional[Dict[str, Any]] = None,
1517
+ **kwargs
1518
+ ) -> List[SearchResult]:
1519
+ """
1520
+ Perform hybrid search combining dense embeddings and ColBERT token matching.
1521
+
1522
+ Args:
1523
+ query: Text query
1524
+ query_tokens: Optional pre-computed query token embeddings
1525
+ table: Table name
1526
+ schema: Schema name
1527
+ top_k: Number of final results
1528
+ dense_weight: Weight for dense similarity scores (0-1)
1529
+ colbert_weight: Weight for ColBERT scores (0-1)
1530
+ metadata_filters: Metadata filters to apply
1531
+
1532
+ Returns:
1533
+ List of SearchResult objects with combined scores
1534
+ """
1535
+ if not self._connected:
1536
+ await self.connection()
1537
+
1538
+ table = table or self.table_name
1539
+ schema = schema or self.schema
1540
+
1541
+ # Fetch more candidates for reranking
1542
+ candidate_count = min(top_k * 3, 100)
1543
+
1544
+ # Get dense similarity results
1545
+ dense_results = await self.similarity_search(
1546
+ query=query,
1547
+ table=table,
1548
+ schema=schema,
1549
+ limit=candidate_count,
1550
+ metadata_filters=metadata_filters,
1551
+ **kwargs
1552
+ )
1553
+
1554
+ # Get ColBERT results if token embeddings provided
1555
+ colbert_results = []
1556
+ if query_tokens is not None:
1557
+ colbert_results = await self.colbert_search(
1558
+ query_tokens=query_tokens,
1559
+ table=table,
1560
+ schema=schema,
1561
+ top_k=candidate_count,
1562
+ metadata_filters=metadata_filters
1563
+ )
1564
+
1565
+ # Combine and rerank results
1566
+ combined_results = self._combine_search_results(
1567
+ dense_results=dense_results,
1568
+ colbert_results=colbert_results,
1569
+ dense_weight=dense_weight,
1570
+ colbert_weight=colbert_weight
1571
+ )
1572
+
1573
+ # Return top-k results
1574
+ return combined_results[:top_k]
1575
+
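The two weights control the blend between channels; omitting `query_tokens` degrades gracefully to dense-only retrieval. A sketch favoring the dense channel:

```python
import numpy as np

# Hypothetical blend: 70% dense similarity, 30% ColBERT MaxSim.
results = await store.hybrid_search(
    query="return policy for power tools",
    query_tokens=np.random.rand(8, 128).astype(np.float32),
    top_k=5,
    dense_weight=0.7,
    colbert_weight=0.3,
)
```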
1576
+ def _combine_search_results(
1577
+ self,
1578
+ dense_results: List[SearchResult],
1579
+ colbert_results: List[SearchResult],
1580
+ dense_weight: float,
1581
+ colbert_weight: float
1582
+ ) -> List[SearchResult]:
1583
+ """
1584
+ Combine and rerank results from dense and ColBERT searches.
1585
+ """
1586
+ # Create lookup dictionaries
1587
+ dense_lookup = {result.id: result for result in dense_results}
1588
+ colbert_lookup = {result.id: result for result in colbert_results}
1589
+
1590
+ # Get all unique document IDs
1591
+ all_ids = set(dense_lookup.keys()) | set(colbert_lookup.keys())
1592
+
1593
+ # Normalize scores to 0-1 range
1594
+ if dense_results:
1595
+ dense_scores = [r.score for r in dense_results]
1596
+ dense_min, dense_max = min(dense_scores), max(dense_scores)
1597
+ dense_range = dense_max - dense_min if dense_max != dense_min else 1
1598
+ else:
1599
+ dense_min, dense_range = 0, 1
1600
+
1601
+ if colbert_results:
1602
+ colbert_scores = [r.score for r in colbert_results]
1603
+ colbert_min, colbert_max = min(colbert_scores), max(colbert_scores)
1604
+ colbert_range = colbert_max - colbert_min if colbert_max != colbert_min else 1
1605
+ else:
1606
+ colbert_min, colbert_range = 0, 1
1607
+
1608
+ # Combine results
1609
+ combined_results = []
1610
+ for doc_id in all_ids:
1611
+ dense_result = dense_lookup.get(doc_id)
1612
+ colbert_result = colbert_lookup.get(doc_id)
1613
+
1614
+ # Normalize scores
1615
+ dense_score_norm = 0
1616
+ if dense_result:
1617
+ dense_score_norm = (dense_result.score - dense_min) / dense_range
1618
+
1619
+ colbert_score_norm = 0
1620
+ if colbert_result:
1621
+ colbert_score_norm = (colbert_result.score - colbert_min) / colbert_range
1622
+
1623
+ # Calculate combined score
1624
+ combined_score = (
1625
+ dense_weight * dense_score_norm +
1626
+ colbert_weight * colbert_score_norm
1627
+ )
1628
+
1629
+ # Use the result with more complete information
1630
+ primary_result = dense_result or colbert_result
1631
+
1632
+ # Create combined result
1633
+ combined_result = SearchResult(
1634
+ id=primary_result.id,
1635
+ content=primary_result.content,
1636
+ metadata={
1637
+ **primary_result.metadata,
1638
+ 'dense_score': dense_result.score if dense_result else 0,
1639
+ 'colbert_score': colbert_result.score if colbert_result else 0,
1640
+ 'combined_score': combined_score
1641
+ },
1642
+ score=combined_score
1643
+ )
1644
+ combined_results.append(combined_result)
1645
+
1646
+ # Sort by combined score (descending)
1647
+ combined_results.sort(key=lambda x: x.score, reverse=True)
1648
+
1649
+ return combined_results
1650
+
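Since each score list is min-max normalized before blending, a quick worked example of the arithmetic above:

```python
# Dense scores [0.2, 0.6, 1.0] and ColBERT scores [10.0, 30.0],
# weights 0.7 / 0.3, for a document scoring 0.6 dense and 30.0 ColBERT:
dense_norm = (0.6 - 0.2) / (1.0 - 0.2)            # 0.5
colbert_norm = (30.0 - 10.0) / (30.0 - 10.0)      # 1.0
combined = 0.7 * dense_norm + 0.3 * colbert_norm  # 0.65
```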
1651
+ async def mmr_search(
1652
+ self,
1653
+ query: str,
1654
+ table: str = None,
1655
+ schema: str = None,
1656
+ k: int = 10,
1657
+ fetch_k: int = None,
1658
+ lambda_mult: float = 0.5,
1659
+ metadata_filters: Optional[Dict[str, Any]] = None,
1660
+ score_threshold: Optional[float] = None,
1661
+ metric: str = None,
1662
+ embedding_column: str = 'embedding',
1663
+ content_column: str = 'document',
1664
+ metadata_column: str = 'cmetadata',
1665
+ id_column: str = 'id',
1666
+ additional_columns: Optional[List[str]] = None
1667
+ ) -> List[SearchResult]:
1668
+ """
1669
+ Perform Maximal Marginal Relevance (MMR) search to balance relevance and diversity.
1670
+
1671
+ MMR helps avoid redundant results by selecting documents that are relevant to the query
1672
+ but diverse from each other.
1673
+
1674
+ Args:
1675
+ query: The search query text
1676
+ table: Table name (optional, uses default if not provided)
1677
+ schema: Schema name (optional, uses default if not provided)
1678
+ k: Number of final results to return
1679
+ fetch_k: Number of candidate documents to fetch (default: k * 3)
1680
+ lambda_mult: MMR diversity parameter (0-1):
1681
+ - 1.0 = pure relevance (no diversity)
1682
+ - 0.0 = pure diversity (no relevance)
1683
+ - 0.5 = balanced (default)
1684
+ metadata_filters: Dictionary of metadata filters to apply
1685
+ score_threshold: Maximum distance threshold for initial candidates
1686
+ metric: Distance metric to use ('COSINE', 'L2', 'IP')
1687
+ embedding_column: Name of the embedding column
1688
+ content_column: Name of the content column
1689
+ metadata_column: Name of the metadata column
1690
+ id_column: Name of the ID column
1691
+ additional_columns: Additional columns to include in results
1692
+
1693
+ Returns:
1694
+ List of SearchResult objects selected via MMR algorithm
1695
+ """
1696
+ if not self._connected:
1697
+ await self.connection()
1698
+
1699
+ # Default to fetching 3x more candidates than final results
1700
+ if fetch_k is None:
1701
+ fetch_k = max(k * 3, 20)
1702
+
1703
+ # Step 1: Get initial candidates using similarity search
1704
+ candidates = await self.similarity_search(
1705
+ query=query,
1706
+ table=table,
1707
+ schema=schema,
1708
+ limit=fetch_k,
1709
+ metadata_filters=metadata_filters,
1710
+ score_threshold=score_threshold,
1711
+ metric=metric,
1712
+ embedding_column=embedding_column,
1713
+ content_column=content_column,
1714
+ metadata_column=metadata_column,
1715
+ id_column=id_column,
1716
+ additional_columns=additional_columns
1717
+ )
1718
+
1719
+ if len(candidates) <= k:
1720
+ # If we have fewer candidates than requested results, return all
1721
+ return candidates
1722
+
1723
+ # Step 2: Get embeddings for MMR computation
1724
+ # We need to fetch the actual embedding vectors for similarity computation
1725
+ candidate_embeddings = await self._fetch_embeddings_for_mmr(
1726
+ candidate_ids=[result.id for result in candidates],
1727
+ table=table,
1728
+ schema=schema,
1729
+ embedding_column=embedding_column,
1730
+ id_column=id_column
1731
+ )
1732
+
1733
+ # Step 3: Get query embedding
1734
+ query_embedding = self._embed_.embed_query(query)
1735
+
1736
+ # Step 4: Run MMR algorithm
1737
+ selected_results = self._mmr_algorithm(
1738
+ query_embedding=query_embedding,
1739
+ candidates=candidates,
1740
+ candidate_embeddings=candidate_embeddings,
1741
+ k=k,
1742
+ lambda_mult=lambda_mult,
1743
+ metric=metric or self.distance_strategy
1744
+ )
1745
+
1746
+ self.logger.info(
1747
+ f"MMR search selected {len(selected_results)} results from {len(candidates)} candidates "
1748
+ f"(λ={lambda_mult})"
1749
+ )
1750
+
1751
+ return selected_results
1752
+
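A usage sketch; `lambda_mult=1.0` reduces MMR to plain relevance ranking, while lower values trade relevance for diversity:

```python
# Hypothetical diversity-aware retrieval.
results = await store.mmr_search(
    query="warranty coverage",
    k=5,
    fetch_k=25,       # candidate pool to rerank
    lambda_mult=0.5,  # balanced relevance vs. diversity
)
for r in results:
    print(r.metadata["mmr_rank"], r.id)
```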
1753
+ async def _fetch_embeddings_for_mmr(
1754
+ self,
1755
+ candidate_ids: List[str],
1756
+ table: str,
1757
+ schema: str,
1758
+ embedding_column: str,
1759
+ id_column: str
1760
+ ) -> Dict[str, np.ndarray]:
1761
+ """
1762
+ Fetch embedding vectors for candidate documents.
1763
+
1764
+ Args:
1765
+ candidate_ids: List of document IDs to fetch embeddings for
1766
+ table: Table name
1767
+ schema: Schema name
1768
+ embedding_column: Name of the embedding column
1769
+ id_column: Name of the ID column
1770
+
1771
+ Returns:
1772
+ Dictionary mapping document ID to embedding vector
1773
+ """
1774
+ if not self.embedding_store:
1775
+ self.embedding_store = self._define_collection_store(
1776
+ table=table,
1777
+ schema=schema,
1778
+ dimension=self.dimension,
1779
+ id_column=self._id_column,
1780
+ embedding_column=embedding_column,
1781
+ document_column=self._document_column,
1782
+ metadata_column='cmetadata',
1783
+ text_column=self._text_column,
1784
+ )
1785
+
1786
+ # Get column objects
1787
+ id_col = getattr(self.embedding_store, id_column)
1788
+ embedding_col = getattr(self.embedding_store, embedding_column)
1789
+
1790
+ # Build query to fetch embeddings
1791
+ stmt = (
1792
+ select(id_col, embedding_col)
1793
+ .select_from(self.embedding_store)
1794
+ .where(id_col.in_(candidate_ids))
1795
+ )
1796
+
1797
+ embeddings_dict = {}
1798
+ async with self.session() as session:
1799
+ result = await session.execute(stmt)
1800
+ rows = result.fetchall()
1801
+
1802
+ for row in rows:
1803
+ doc_id = row[0]
1804
+ embedding = row[1]
1805
+
1806
+ # Convert to numpy array if needed
1807
+ if isinstance(embedding, (list, tuple)):
1808
+ embeddings_dict[doc_id] = np.array(embedding, dtype=np.float32)
1809
+ elif hasattr(embedding, '__array__'):
1810
+ embeddings_dict[doc_id] = np.array(embedding, dtype=np.float32)
1811
+ else:
1812
+ # Handle pgvector Vector type
1813
+ embeddings_dict[doc_id] = np.array(embedding, dtype=np.float32)
1814
+
1815
+ return embeddings_dict
1816
+
1817
+ def _mmr_algorithm(
1818
+ self,
1819
+ query_embedding: np.ndarray,
1820
+ candidates: List[SearchResult],
1821
+ candidate_embeddings: Dict[str, np.ndarray],
1822
+ k: int,
1823
+ lambda_mult: float,
1824
+ metric: str
1825
+ ) -> List[SearchResult]:
1826
+ """
1827
+ Core MMR algorithm implementation.
1828
+
1829
+ Args:
1830
+ query_embedding: Query embedding vector
1831
+ candidates: List of candidate SearchResult objects
1832
+ candidate_embeddings: Dictionary mapping doc ID to embedding vector
1833
+ k: Number of results to select
1834
+ lambda_mult: MMR diversity parameter (0-1)
1835
+ metric: Distance metric to use
1836
+
1837
+ Returns:
1838
+ List of selected SearchResult objects ordered by MMR score
1839
+ """
1840
+ if len(candidates) <= k:
1841
+ return candidates
1842
+
1843
+ # Convert query embedding to numpy array
1844
+ if not isinstance(query_embedding, np.ndarray):
1845
+ query_embedding = np.array(query_embedding, dtype=np.float32)
1846
+
1847
+ # Prepare data structures
1848
+ selected_indices = []
1849
+ remaining_indices = list(range(len(candidates)))
1850
+
1851
+ # Step 1: Select the most relevant document first
1852
+ query_similarities = []
1853
+ for candidate in candidates:
1854
+ doc_embedding = candidate_embeddings.get(candidate.id)
1855
+ if doc_embedding is not None:
1856
+ similarity = self._compute_similarity(query_embedding, doc_embedding, metric)
1857
+ query_similarities.append(similarity)
1858
+ else:
1859
+ # Fallback to distance score if embedding not available
1860
+ # Convert distance to similarity (lower distance = higher similarity)
1861
+ query_similarities.append(1.0 / (1.0 + candidate.score))
1862
+
1863
+ # Select the most similar document first
1864
+ best_idx = np.argmax(query_similarities)
1865
+ selected_indices.append(best_idx)
1866
+ remaining_indices.remove(best_idx)
1867
+
1868
+ # Step 2: Iteratively select remaining documents using MMR
1869
+ for _ in range(min(k - 1, len(remaining_indices))):
1870
+ mmr_scores = []
1871
+
1872
+ for idx in remaining_indices:
1873
+ candidate = candidates[idx]
1874
+ doc_embedding = candidate_embeddings.get(candidate.id)
1875
+
1876
+ if doc_embedding is None:
1877
+ # Fallback scoring if embedding not available
1878
+ mmr_score = lambda_mult * query_similarities[idx]
1879
+ mmr_scores.append(mmr_score)
1880
+ continue
1881
+
1882
+ # Relevance: similarity to query
1883
+ relevance = query_similarities[idx]
1884
+
1885
+ # Diversity: maximum similarity to already selected documents
1886
+ max_similarity_to_selected = 0.0
1887
+ for selected_idx in selected_indices:
1888
+ selected_candidate = candidates[selected_idx]
1889
+ selected_embedding = candidate_embeddings.get(selected_candidate.id)
1890
+
1891
+ if selected_embedding is not None:
1892
+ similarity = self._compute_similarity(doc_embedding, selected_embedding, metric)
1893
+ max_similarity_to_selected = max(max_similarity_to_selected, similarity)
1894
+
1895
+ # MMR formula: λ * relevance - (1-λ) * max_similarity_to_selected
1896
+ mmr_score = (
1897
+ lambda_mult * relevance -
1898
+ (1.0 - lambda_mult) * max_similarity_to_selected
1899
+ )
1900
+ mmr_scores.append(mmr_score)
1901
+
1902
+ # Select document with highest MMR score
1903
+ if mmr_scores:
1904
+ best_remaining_idx = np.argmax(mmr_scores)
1905
+ best_idx = remaining_indices[best_remaining_idx]
1906
+ selected_indices.append(best_idx)
1907
+ remaining_indices.remove(best_idx)
1908
+
1909
+ # Step 3: Return selected results with MMR scores in metadata
1910
+ selected_results = []
1911
+ for i, idx in enumerate(selected_indices):
1912
+ result = candidates[idx]
1913
+ # Add MMR ranking to metadata
1914
+ enhanced_metadata = dict(result.metadata)
1915
+ enhanced_metadata['mmr_rank'] = i + 1
1916
+ enhanced_metadata['mmr_lambda'] = lambda_mult
1917
+ enhanced_metadata['original_rank'] = idx + 1
1918
+
1919
+ enhanced_result = SearchResult(
1920
+ id=result.id,
1921
+ content=result.content,
1922
+ metadata=enhanced_metadata,
1923
+ score=result.score
1924
+ )
1925
+ selected_results.append(enhanced_result)
1926
+
1927
+ return selected_results
1928
+
1929
+ def _compute_similarity(
1930
+ self,
1931
+ embedding1: np.ndarray,
1932
+ embedding2: np.ndarray,
1933
+ metric: Union[str, Any]
1934
+ ) -> float:
1935
+ """
1936
+ Compute similarity between two embeddings based on the specified metric.
1937
+
1938
+ Args:
1939
+ embedding1: First embedding vector (numpy array or list)
1940
+ embedding2: Second embedding vector (numpy array or list)
1941
+ metric: Distance metric ('COSINE', 'L2', 'IP', etc.)
1942
+
1943
+ Returns:
1944
+ Similarity score (higher = more similar)
1945
+ """
1946
+ # Convert both embeddings to numpy arrays (covers lists, tuples,
+ # and pgvector values in one step)
+ if not isinstance(embedding1, np.ndarray):
+ embedding1 = np.array(embedding1, dtype=np.float32)
+ if not isinstance(embedding2, np.ndarray):
+ embedding2 = np.array(embedding2, dtype=np.float32)
1957
+
1958
+ # Ensure embeddings are 2D arrays for sklearn
1959
+ emb1 = embedding1.reshape(1, -1)
1960
+ emb2 = embedding2.reshape(1, -1)
1961
+
1962
+ # Convert string metrics to DistanceStrategy enum if needed
1963
+ if isinstance(metric, str):
1964
+ metric_mapping = {
1965
+ 'COSINE': DistanceStrategy.COSINE,
1966
+ 'L2': DistanceStrategy.EUCLIDEAN_DISTANCE,
1967
+ 'EUCLIDEAN': DistanceStrategy.EUCLIDEAN_DISTANCE,
1968
+ 'IP': DistanceStrategy.MAX_INNER_PRODUCT,
1969
+ 'DOT': DistanceStrategy.DOT_PRODUCT,
1970
+ 'DOT_PRODUCT': DistanceStrategy.DOT_PRODUCT,
1971
+ 'MAX_INNER_PRODUCT': DistanceStrategy.MAX_INNER_PRODUCT
1972
+ }
1973
+ strategy = metric_mapping.get(metric.upper(), DistanceStrategy.COSINE)
1974
+ else:
1975
+ strategy = metric
1976
+
1977
+ if strategy == DistanceStrategy.COSINE:
1978
+ # Cosine similarity (returns similarity, not distance)
1979
+ similarity = cosine_similarity(emb1, emb2)[0, 0]
1980
+ return float(similarity)
1981
+
1982
+ elif strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
1983
+ # Convert Euclidean distance to similarity
1984
+ distance = euclidean_distances(emb1, emb2)[0, 0]
1985
+ similarity = 1.0 / (1.0 + distance)
1986
+ return float(similarity)
1987
+
1988
+ elif strategy in [DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.DOT_PRODUCT]:
1989
+ # Dot product (inner product)
1990
+ similarity = np.dot(embedding1.flatten(), embedding2.flatten())
1991
+ return float(similarity)
1992
+
1993
+ else:
1994
+ # Default to cosine similarity
1995
+ similarity = cosine_similarity(emb1, emb2)[0, 0]
1996
+ return float(similarity)
1997
+
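A toy check of the three metric behaviors implemented above (pure numpy/sklearn, no database needed):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

a = np.array([1.0, 0.0], dtype=np.float32)
b = np.array([2.0, 0.0], dtype=np.float32)

print(cosine_similarity(a.reshape(1, -1), b.reshape(1, -1))[0, 0])  # 1.0
print(float(np.dot(a, b)))                  # 2.0 (inner product keeps magnitude)
print(1.0 / (1.0 + np.linalg.norm(a - b)))  # 0.5 (L2 distance as similarity)
```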
1998
+ async def delete_documents(
1999
+ self,
2000
+ documents: Optional[List[Document]] = None,
2001
+ pk: str = 'source_type',
2002
+ values: Optional[Union[str, List[str]]] = None,
2003
+ table: Optional[str] = None,
2004
+ schema: Optional[str] = None,
2005
+ metadata_column: Optional[str] = None,
2006
+ **kwargs
2007
+ ) -> int:
2008
+ """
2009
+ Delete documents from the vector store based on metadata field values.
2010
+
2011
+ Args:
2012
+ documents: List of documents whose metadata values will be used for deletion.
2013
+ If provided, the pk field values will be extracted from these documents.
2014
+ pk: The metadata field name to use for deletion (default: 'source_type')
2015
+ values: Specific values to delete. Can be a single string or list of strings.
2016
+ If provided, this takes precedence over extracting from documents.
2017
+ table: Override table name
2018
+ schema: Override schema name
2019
+ metadata_column: Override metadata column name
2020
+
2021
+ Returns:
2022
+ int: Number of documents deleted
2023
+
2024
+ Examples:
2025
+ # Delete all documents with source_type 'papers'
2026
+ deleted_count = await store.delete_documents(values='papers')
2027
+
2028
+ # Delete documents with multiple source types
2029
+ deleted_count = await store.delete_documents(values=['papers', 'reports'])
2030
+
2031
+ # Delete based on documents' metadata
2032
+ docs_to_delete = [Document(page_content="test", metadata={"source_type": "papers"})]
2033
+ deleted_count = await store.delete_documents(documents=docs_to_delete)
2034
+
2035
+ # Delete by different metadata field
2036
+ deleted_count = await store.delete_documents(pk='category', values='obsolete')
2037
+ """
2038
+ if not self._connected:
2039
+ await self.connection()
2040
+
2041
+ # Use defaults from instance if not provided
2042
+ table = table or self.table_name
2043
+ schema = schema or self.schema
2044
+ metadata_column = metadata_column or 'cmetadata'
2045
+
2046
+ # Extract values to delete
2047
+ delete_values = []
2048
+
2049
+ if values is not None:
2050
+ # Use provided values
2051
+ if isinstance(values, str):
2052
+ delete_values = [values]
2053
+ else:
2054
+ delete_values = list(values)
2055
+ elif documents:
2056
+ # Extract values from documents metadata
2057
+ for doc in documents:
2058
+ if hasattr(doc, 'metadata') and doc.metadata and pk in doc.metadata:
2059
+ value = doc.metadata[pk]
2060
+ if value and value not in delete_values:
2061
+ delete_values.append(value)
2062
+ else:
2063
+ raise ValueError("Either 'documents' or 'values' parameter must be provided")
2064
+
2065
+ if not delete_values:
2066
+ self.logger.warning(f"No values found for field '{pk}' to delete")
2067
+ return 0
2068
+
2069
+ # Construct full table name
2070
+ full_table_name = f"{schema}.{table}" if schema != 'public' else table
2071
+
2072
+ deleted_count = 0
2073
+
2074
+ try:
2075
+ async with self.session() as session:
2076
+ for value in delete_values:
2077
+ # Use JSONB operator to find matching metadata
2078
+ delete_query = text(f"""
2079
+ DELETE FROM {full_table_name}
2080
+ WHERE {metadata_column}->>:pk = :value
2081
+ """)
2082
+
2083
+ result = await session.execute(
2084
+ delete_query,
2085
+ {"pk": pk, "value": str(value)}
2086
+ )
2087
+
2088
+ rows_deleted = result.rowcount
2089
+ deleted_count += rows_deleted
2090
+
2091
+ self.logger.info(
2092
+ f"Deleted {rows_deleted} documents with {pk}='{value}' from {full_table_name}"
2093
+ )
2094
+
2095
+ self.logger.info(f"Total deleted: {deleted_count} documents")
2096
+ return deleted_count
2097
+
2098
+ except Exception as e:
2099
+ self.logger.error(f"Error deleting documents: {e}")
2100
+ raise RuntimeError(f"Failed to delete documents: {e}") from e
2101
+
2102
+ async def delete_documents_by_filter(
2103
+ self,
2104
+ filter_dict: Dict[str, Union[str, List[str]]],
2105
+ table: Optional[str] = None,
2106
+ schema: Optional[str] = None,
2107
+ metadata_column: Optional[str] = None,
2108
+ **kwargs
2109
+ ) -> int:
2110
+ """
2111
+ Delete documents based on multiple metadata field conditions.
2112
+
2113
+ Args:
2114
+ filter_dict: Dictionary of field_name: value(s) pairs for deletion criteria
2115
+ table: Override table name
2116
+ schema: Override schema name
2117
+ metadata_column: Override metadata column name
2118
+
2119
+ Returns:
2120
+ int: Number of documents deleted
2121
+
2122
+ Example:
2123
+ # Delete documents with source_type='papers' AND category='research'
2124
+ deleted = await store.delete_documents_by_filter({
2125
+ 'source_type': 'papers',
2126
+ 'category': 'research'
2127
+ })
2128
+
2129
+ # Delete documents with source_type in ['papers', 'reports']
2130
+ deleted = await store.delete_documents_by_filter({
2131
+ 'source_type': ['papers', 'reports']
2132
+ })
2133
+ """
2134
+ if not self._connected:
2135
+ await self.connection()
2136
+
2137
+ if not filter_dict:
2138
+ raise ValueError("filter_dict cannot be empty")
2139
+
2140
+ # Use defaults from instance if not provided
2141
+ table = table or self.table_name
2142
+ schema = schema or self.schema
2143
+ metadata_column = metadata_column or 'cmetadata'
2144
+
2145
+ # Construct full table name
2146
+ full_table_name = f"{schema}.{table}" if schema != 'public' else table
2147
+
2148
+ # Build WHERE conditions
2149
+ where_conditions = []
2150
+ params = {}
2151
+
2152
+ for field, values in filter_dict.items():
2153
+ if isinstance(values, (list, tuple)):
2154
+ # Handle multiple values with IN operator
2155
+ placeholders = []
2156
+ for i, value in enumerate(values):
2157
+ param_name = f"{field}_{i}"
2158
+ placeholders.append(f":{param_name}")
2159
+ params[param_name] = str(value)
2160
+
2161
+ condition = f"{metadata_column}->>'{field}' IN ({', '.join(placeholders)})"
2162
+ where_conditions.append(condition)
2163
+ else:
2164
+ # Handle single value
2165
+ param_name = f"{field}_single"
2166
+ where_conditions.append(f"{metadata_column}->>'{field}' = :{param_name}")
2167
+ params[param_name] = str(values)
2168
+
2169
+ # Combine conditions with AND
2170
+ where_clause = " AND ".join(where_conditions)
2171
+
2172
+ delete_query = text(f"""
2173
+ DELETE FROM {full_table_name}
2174
+ WHERE {where_clause}
2175
+ """)
2176
+
2177
+ try:
2178
+ async with self.session() as session:
2179
+ result = await session.execute(delete_query, params)
2180
+ deleted_count = result.rowcount
2181
+
2182
+ self.logger.info(
2183
+ f"Deleted {deleted_count} documents from {full_table_name} "
2184
+ f"with filter: {filter_dict}"
2185
+ )
2186
+
2187
+ return deleted_count
2188
+
2189
+ except Exception as e:
2190
+ self.logger.error(f"Error deleting documents by filter: {e}")
2191
+ raise RuntimeError(f"Failed to delete documents by filter: {e}") from e
2192
+
2193
+ async def delete_all_documents(
2194
+ self,
2195
+ table: Optional[str] = None,
2196
+ schema: Optional[str] = None,
2197
+ confirm: bool = False,
2198
+ **kwargs
2199
+ ) -> int:
2200
+ """
2201
+ Delete ALL documents from the vector store table.
2202
+
2203
+ WARNING: This will delete all data in the table!
2204
+
2205
+ Args:
2206
+ table: Override table name
2207
+ schema: Override schema name
2208
+ confirm: Must be set to True to proceed with deletion
2209
+
2210
+ Returns:
2211
+ int: Number of documents deleted
2212
+ """
2213
+ if not confirm:
2214
+ raise ValueError(
2215
+ "This operation will delete ALL documents. "
2216
+ "Set confirm=True to proceed."
2217
+ )
2218
+
2219
+ if not self._connected:
2220
+ await self.connection()
2221
+
2222
+ # Use defaults from instance if not provided
2223
+ table = table or self.table_name
2224
+ schema = schema or self.schema
2225
+
2226
+ # Construct full table name
2227
+ full_table_name = f"{schema}.{table}" if schema != 'public' else table
2228
+
2229
+ try:
2230
+ async with self.session() as session:
2231
+ # First count existing documents
2232
+ count_query = text(f"SELECT COUNT(*) FROM {full_table_name}")
2233
+ count_result = await session.execute(count_query)
2234
+ total_docs = count_result.scalar()
2235
+
2236
+ if total_docs == 0:
2237
+ self.logger.info(f"No documents to delete from {full_table_name}")
2238
+ return 0
2239
+
2240
+ # Delete all documents
2241
+ delete_query = text(f"DELETE FROM {full_table_name}")
2242
+ result = await session.execute(delete_query)
2243
+ deleted_count = result.rowcount
2244
+
2245
+ self.logger.warning(
2246
+ f"DELETED ALL {deleted_count} documents from {full_table_name}"
2247
+ )
2248
+
2249
+ return deleted_count
2250
+
2251
+ except Exception as e:
2252
+ self.logger.error(f"Error deleting all documents: {e}")
2253
+ raise RuntimeError(f"Failed to delete all documents: {e}") from e
2254
+
2255
+ async def delete_documents_by_ids(
2256
+ self,
2257
+ document_ids: List[str],
2258
+ table: Optional[str] = None,
2259
+ schema: Optional[str] = None,
2260
+ id_column: Optional[str] = None,
2261
+ **kwargs
2262
+ ) -> int:
2263
+ """
2264
+ Delete documents by their IDs.
2265
+
2266
+ Args:
2267
+ document_ids: List of document IDs to delete
2268
+ table: Override table name
2269
+ schema: Override schema name
2270
+ id_column: Override ID column name
2271
+
2272
+ Returns:
2273
+ int: Number of documents deleted
2274
+
2275
+ Example:
2276
+ deleted_count = await store.delete_documents_by_ids([
2277
+ "doc_1", "doc_2", "doc_3"
2278
+ ])
2279
+ """
2280
+ if not self._connected:
2281
+ await self.connection()
2282
+
2283
+ if not document_ids:
2284
+ self.logger.warning("No document IDs provided for deletion")
2285
+ return 0
2286
+
2287
+ # Use defaults from instance if not provided
2288
+ table = table or self.table_name
2289
+ schema = schema or self.schema
2290
+ id_column = id_column or self._id_column
2291
+
2292
+ # Construct full table name
2293
+ full_table_name = f"{schema}.{table}" if schema != 'public' else table
2294
+
2295
+ # Build parameterized query for multiple IDs
2296
+ placeholders = []
2297
+ params = {}
2298
+ for i, doc_id in enumerate(document_ids):
2299
+ param_name = f"id_{i}"
2300
+ placeholders.append(f":{param_name}")
2301
+ params[param_name] = str(doc_id)
2302
+
2303
+ delete_query = text(f"""
2304
+ DELETE FROM {full_table_name}
2305
+ WHERE {id_column} IN ({', '.join(placeholders)})
2306
+ """)
2307
+
2308
+ try:
2309
+ async with self.session() as session:
2310
+ result = await session.execute(delete_query, params)
2311
+ deleted_count = result.rowcount
2312
+
2313
+ self.logger.info(
2314
+ f"Deleted {deleted_count} documents by IDs from {full_table_name}"
2315
+ )
2316
+
2317
+ return deleted_count
2318
+
2319
+ except Exception as e:
2320
+ self.logger.error(f"Error deleting documents by IDs: {e}")
2321
+ raise RuntimeError(f"Failed to delete documents by IDs: {e}") from e
2322
+
2323
+ # Additional utility method for safer deletions
2324
+ async def count_documents_by_filter(
2325
+ self,
2326
+ filter_dict: Dict[str, Union[str, List[str]]],
2327
+ table: Optional[str] = None,
2328
+ schema: Optional[str] = None,
2329
+ metadata_column: Optional[str] = None,
2330
+ **kwargs
2331
+ ) -> int:
2332
+ """
2333
+ Count documents that would be affected by a filter (useful before deletion).
2334
+
2335
+ Args:
2336
+ filter_dict: Dictionary of field_name: value(s) pairs for criteria
2337
+ table: Override table name
2338
+ schema: Override schema name
2339
+ metadata_column: Override metadata column name
2340
+
2341
+ Returns:
2342
+ int: Number of documents matching the filter
2343
+ """
2344
+ if not self._connected:
2345
+ await self.connection()
2346
+
2347
+ if not filter_dict:
2348
+ return 0
2349
+
2350
+ # Use defaults from instance if not provided
2351
+ table = table or self.table_name
2352
+ schema = schema or self.schema
2353
+ metadata_column = metadata_column or 'cmetadata'
2354
+
2355
+ # Construct full table name
2356
+ full_table_name = f"{schema}.{table}" if schema != 'public' else table
2357
+
2358
+ # Build WHERE conditions (same logic as delete_documents_by_filter)
2359
+ where_conditions = []
2360
+ params = {}
2361
+
2362
+ for field, values in filter_dict.items():
2363
+ if isinstance(values, (list, tuple)):
2364
+ placeholders = []
2365
+ for i, value in enumerate(values):
2366
+ param_name = f"{field}_{i}"
2367
+ placeholders.append(f":{param_name}")
2368
+ params[param_name] = str(value)
2369
+
2370
+ condition = f"{metadata_column}->>'{field}' IN ({', '.join(placeholders)})"
2371
+ where_conditions.append(condition)
2372
+ else:
2373
+ param_name = f"{field}_single"
2374
+ where_conditions.append(f"{metadata_column}->>'{field}' = :{param_name}")
2375
+ params[param_name] = str(values)
2376
+
2377
+ where_clause = " AND ".join(where_conditions)
2378
+ count_query = text(f"""
2379
+ SELECT COUNT(*) FROM {full_table_name}
2380
+ WHERE {where_clause}
2381
+ """)
2382
+
2383
+ try:
2384
+ async with self.session() as session:
2385
+ result = await session.execute(count_query, params)
2386
+ count = result.scalar()
2387
+
2388
+ self.logger.info(
2389
+ f"Found {count} documents matching filter: {filter_dict}"
2390
+ )
2391
+
2392
+ return count
2393
+
2394
+ except Exception as e:
2395
+ self.logger.error(f"Error counting documents: {e}")
2396
+ raise RuntimeError(f"Failed to count documents: {e}") from e
2397
+
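The count-then-delete pattern this enables, sketched with an illustrative filter and threshold:

```python
# Hypothetical safety check before a bulk deletion.
criteria = {"source_type": ["papers", "reports"]}
n = await store.count_documents_by_filter(criteria)
if 0 < n < 1000:  # arbitrary sanity threshold
    deleted = await store.delete_documents_by_filter(criteria)
    assert deleted == n
```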
2398
+ async def from_documents(
2399
+ self,
2400
+ documents: List[Document],
2401
+ table: str = None,
2402
+ schema: str = None,
2403
+ embedding_column: str = 'embedding',
2404
+ content_column: str = 'document',
2405
+ metadata_column: str = 'cmetadata',
2406
+ chunk_size: int = 8192,
2407
+ chunk_overlap: int = 200,
2408
+ store_full_document: bool = True,
2409
+ **kwargs
2410
+ ) -> Dict[str, Any]:
2411
+ """
2412
+ Add documents using late chunking strategy.
2413
+
2414
+ Args:
2415
+ documents: List of Document objects to process
2416
+ table: Table name
2417
+ schema: Schema name
2418
+ embedding_column: Name of embedding column
2419
+ content_column: Name of content column
2420
+ metadata_column: Name of metadata column
2421
+ chunk_size: Maximum size of each chunk
2422
+ chunk_overlap: Overlap between chunks
2423
+ store_full_document: Whether to store full document alongside chunks
2424
+
2425
+ Returns:
2426
+ Dictionary with processing statistics
2427
+ """
2428
+ if not self._connected:
2429
+ await self.connection()
2430
+
2431
+ table = table or self.table_name
2432
+ schema = schema or self.schema
2433
+ 
2435
+ # Initialize late chunking processor
2436
+ chunking_processor = LateChunkingProcessor(
2437
+ vector_store=self,
2438
+ chunk_size=chunk_size,
2439
+ chunk_overlap=chunk_overlap
2440
+ )
2441
+
2442
+ # Ensure embedding store is initialized
2443
+ if self.embedding_store is None:
2444
+ self.embedding_store = self._define_collection_store(
2445
+ table=table,
2446
+ schema=schema,
2447
+ dimension=self.dimension,
2448
+ id_column=self._id_column,
2449
+ embedding_column=embedding_column,
2450
+ document_column=content_column,
2451
+ metadata_column=metadata_column,
2452
+ text_column=self._text_column,
2453
+ )
2454
+
2455
+ all_inserts = []
2456
+ stats = {
2457
+ 'documents_processed': 0,
2458
+ 'chunks_created': 0,
2459
+ 'full_documents_stored': 0
2460
+ }
2461
+ for doc_idx, document in enumerate(documents):
2462
+ document_id = f"doc_{doc_idx:06d}_{uuid.uuid4().hex[:8]}"
2463
+
2464
+ # Process document with late chunking
2465
+ full_embedding, chunk_infos = await chunking_processor.process_document_late_chunking(
2466
+ document_text=document.page_content,
2467
+ document_id=document_id,
2468
+ metadata=document.metadata
2469
+ )
2470
+ # Store full document if requested
2471
+ if store_full_document:
2472
+ full_doc_metadata = {
2473
+ **(document.metadata or {}),
2474
+ 'document_id': document_id,
2475
+ 'is_full_document': True,
2476
+ 'total_chunks': len(chunk_infos),
2477
+ 'document_type': 'parent',
2478
+ 'chunking_strategy': 'late_chunking'
2479
+ }
2480
+
2481
+ all_inserts.append({
2482
+ self._id_column: document_id,
2483
+ embedding_column: full_embedding.tolist(),
2484
+ content_column: document.page_content,
2485
+ metadata_column: full_doc_metadata
2486
+ })
2487
+ stats['full_documents_stored'] += 1
2488
+
2489
+ # Store all chunks
2490
+ for chunk_info in chunk_infos:
2491
+ embed = chunk_info.chunk_embedding if isinstance(chunk_info.chunk_embedding, list) else chunk_info.chunk_embedding.tolist()
2492
+ all_inserts.append({
2493
+ self._id_column: chunk_info.chunk_id,
2494
+ embedding_column: embed,
2495
+ content_column: chunk_info.chunk_text,
2496
+ metadata_column: chunk_info.metadata
2497
+ })
2498
+ stats['chunks_created'] += 1
2499
+
2500
+ stats['documents_processed'] += 1
2501
+
2502
+ # Bulk insert all data
2503
+ if all_inserts:
2504
+ insert_stmt = insert(self.embedding_store)
2505
+
2506
+ try:
2507
+ async with self.session() as session:
2508
+ await session.execute(insert_stmt, all_inserts)
2509
+
2510
+ self.logger.info(
2511
+ f"✅ Late chunking complete: {stats['documents_processed']} documents → "
2512
+ f"{stats['chunks_created']} chunks + {stats['full_documents_stored']} full docs"
2513
+ )
2514
+
2515
+ except Exception as e:
2516
+ self.logger.error(f"Error in late chunking insert: {e}")
2517
+ raise
2518
+
2519
+ return stats
2520
+
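A sketch of late-chunking ingestion; the `Document` type is assumed to be the LangChain document class referenced elsewhere in this module:

```python
from langchain_core.documents import Document  # assumed import

docs = [
    Document(
        page_content="Long manual text ...",
        metadata={"source_type": "manuals"},
    )
]
stats = await store.from_documents(docs, chunk_size=4096, chunk_overlap=200)
print(stats)  # {'documents_processed': 1, 'chunks_created': ..., ...}
```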
+    async def document_search(
+        self,
+        query: str,
+        table: str = None,
+        schema: str = None,
+        limit: int = 10,
+        search_chunks: bool = True,
+        search_full_docs: bool = False,
+        rerank_with_context: bool = True,
+        context_window: int = 2,
+        **kwargs
+    ) -> List[SearchResult]:
+        """
+        Search with late chunking context awareness.
+
+        Args:
+            query: Search query
+            table: Table name
+            schema: Schema name
+            limit: Number of results
+            search_chunks: Whether to search chunk-level embeddings
+            search_full_docs: Whether to search full document embeddings
+            rerank_with_context: Whether to rerank using surrounding chunks
+            context_window: Number of adjacent chunks to include for context
+
+        Returns:
+            List of SearchResult objects with enhanced context
+        """
+        results = []
+
+        # Search chunks if requested
+        if search_chunks:
+            chunk_filters = {'is_chunk': True}
+            chunk_results = await self.similarity_search(
+                query=query,
+                table=table,
+                schema=schema,
+                limit=limit * 2,  # Get more candidates for reranking
+                metadata_filters=chunk_filters,
+                **kwargs
+            )
+            results.extend(chunk_results)
+
+        # Search full documents if requested
+        if search_full_docs:
+            doc_filters = {'is_full_document': True}
+            doc_results = await self.similarity_search(
+                query=query,
+                table=table,
+                schema=schema,
+                limit=limit,
+                metadata_filters=doc_filters,
+                **kwargs
+            )
+            results.extend(doc_results)
+
+        # Rerank with context if requested
+        if rerank_with_context and search_chunks:
+            results = await self._rerank_with_chunk_context(
+                results=results,
+                query=query,
+                table=table,
+                schema=schema,
+                context_window=context_window
+            )
+
+        # Sort ascending by score (distance-style scores: lower is better) and truncate
+        results.sort(key=lambda x: x.score)
+        return results[:limit]
+
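A hedged usage sketch for `document_search`: `store` stands for an already-connected instance of this vector-store class, and the table and query values are placeholders, not anything shipped in the package.

```python
import asyncio

async def main():
    # `store` is assumed to be a connected instance of this class;
    # constructing it is outside the scope of this diff.
    results = await store.document_search(
        query="quarterly revenue by region",
        table="reports",
        schema="public",
        limit=5,
        search_chunks=True,        # search chunk-level embeddings
        search_full_docs=False,    # skip parent-document embeddings
        rerank_with_context=True,  # pull in neighbouring chunks before scoring
        context_window=2,
    )
    for r in results:
        print(r.score, r.metadata.get('chunk_index'), r.content[:80])

# asyncio.run(main())  # uncomment once `store` exists
```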
+    async def _rerank_with_chunk_context(
+        self,
+        results: List[SearchResult],
+        query: str,
+        table: str,
+        schema: str,
+        context_window: int = 2
+    ) -> List[SearchResult]:
+        """
+        Rerank results by including surrounding chunk context.
+        """
+        enhanced_results = []
+
+        for result in results:
+            if not result.metadata.get('is_chunk'):
+                enhanced_results.append(result)
+                continue
+
+            # Get surrounding chunks for context
+            parent_id = result.metadata.get('parent_document_id')
+            chunk_index = result.metadata.get('chunk_index', 0)
+
+            if parent_id:
+                try:
+                    # Find adjacent chunks
+                    context_chunks = await self._get_adjacent_chunks(
+                        parent_id=parent_id,
+                        center_chunk_index=chunk_index,
+                        window_size=context_window,
+                        table=table,
+                        schema=schema
+                    )
+
+                    # Combine text with context
+                    combined_text = self._combine_chunk_context(result, context_chunks)
+
+                    # Re-score with context - ensure embeddings are consistent
+                    context_embedding = self._embed_.embed_query(combined_text)
+                    query_embedding = self._embed_.embed_query(query)
+
+                    # Ensure both embeddings are numpy arrays
+                    if isinstance(context_embedding, list):
+                        context_embedding = np.array(context_embedding, dtype=np.float32)
+                    if isinstance(query_embedding, list):
+                        query_embedding = np.array(query_embedding, dtype=np.float32)
+
+                    # Calculate new similarity score
+                    context_score = self._compute_similarity(
+                        query_embedding, context_embedding, self.distance_strategy
+                    )
+
+                    # Create enhanced result
+                    enhanced_metadata = dict(result.metadata)
+                    enhanced_metadata['context_score'] = context_score
+                    enhanced_metadata['has_context'] = True
+                    enhanced_metadata['context_chunks'] = len(context_chunks)
+
+                    enhanced_result = SearchResult(
+                        id=result.id,
+                        content=combined_text,
+                        metadata=enhanced_metadata,
+                        score=context_score
+                    )
+
+                    enhanced_results.append(enhanced_result)
+
+                except Exception as e:
+                    self.logger.warning(f"Error reranking chunk {result.id}: {e}")
+                    # Fall back to original result if reranking fails
+                    enhanced_results.append(result)
+            else:
+                enhanced_results.append(result)
+
+        return enhanced_results
+
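`_compute_similarity` itself is not shown in this diff. Under a cosine distance strategy it could plausibly behave like the stand-in below, where a lower score means a better match, which is consistent with the ascending sort in `document_search`. This is a sketch under that assumption, not the package's actual implementation.

```python
import numpy as np

def cosine_distance(query_vec: np.ndarray, context_vec: np.ndarray) -> float:
    """Stand-in for a cosine distance strategy: 1 - cos(theta), lower is better."""
    q = query_vec / np.linalg.norm(query_vec)
    c = context_vec / np.linalg.norm(context_vec)
    return float(1.0 - np.dot(q, c))

q = np.array([0.1, 0.9, 0.2], dtype=np.float32)
ctx = np.array([0.2, 0.8, 0.1], dtype=np.float32)
print(cosine_distance(q, ctx))  # small value -> near-duplicate direction
```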
+    async def _get_adjacent_chunks(
+        self,
+        parent_id: str,
+        center_chunk_index: int,
+        window_size: int,
+        table: str,
+        schema: str
+    ) -> List[SearchResult]:
+        """Get adjacent chunks for context."""
+        # Calculate chunk index range
+        start_idx = max(0, center_chunk_index - window_size)
+        end_idx = center_chunk_index + window_size + 1
+
+        # Search for chunks in the range
+        chunk_filters = {
+            'parent_document_id': parent_id,
+            'is_chunk': True
+        }
+
+        # Get all chunks from parent document; the query text is irrelevant
+        # here, since the metadata filters do the actual selection
+        all_chunks = await self.similarity_search(
+            query="dummy",
+            table=table,
+            schema=schema,
+            limit=1000,  # High limit to get all chunks
+            metadata_filters=chunk_filters
+        )
+
+        # Filter to adjacent chunks
+        adjacent_chunks = [
+            chunk for chunk in all_chunks
+            if start_idx <= chunk.metadata.get('chunk_index', 0) < end_idx
+        ]
+
+        # Sort by chunk index
+        adjacent_chunks.sort(key=lambda x: x.metadata.get('chunk_index', 0))
+
+        return adjacent_chunks
+
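The window arithmetic above selects a half-open index range around the matched chunk, clamped at zero for chunks near the start of a document. A self-contained illustration:

```python
# For chunk 5 with window_size=2, the half-open range [3, 8) keeps indices 3..7.
center, window = 5, 2
start_idx = max(0, center - window)
end_idx = center + window + 1
print([i for i in range(10) if start_idx <= i < end_idx])  # [3, 4, 5, 6, 7]
```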
+    def _combine_chunk_context(
+        self,
+        center_result: SearchResult,
+        context_chunks: List[SearchResult]
+    ) -> str:
+        """Combine center chunk with surrounding context."""
+        # Sort context chunks by index
+        context_chunks.sort(key=lambda x: x.metadata.get('chunk_index', 0))
+
+        # Combine text
+        combined_parts = []
+        center_idx = center_result.metadata.get('chunk_index', 0)
+
+        for chunk in context_chunks:
+            chunk_idx = chunk.metadata.get('chunk_index', 0)
+            if chunk_idx == center_idx:
+                # Mark the main chunk
+                combined_parts.append(f"[MAIN] {chunk.content} [/MAIN]")
+            else:
+                combined_parts.append(chunk.content)
+
+        return " ... ".join(combined_parts)
+
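A small illustration of the string this produces, which is what the reranker then re-embeds: the matched chunk is wrapped in `[MAIN]` markers and joined to its neighbours with `" ... "`. The sentences are placeholders.

```python
# Sketch of the combined context string built above.
parts = [
    "Intro to the revenue section.",
    "[MAIN] Q3 revenue grew 12% in EMEA. [/MAIN]",
    "APAC remained flat quarter over quarter.",
]
print(" ... ".join(parts))
```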
+    async def collection_exists(self, table: str, schema: str = 'public') -> bool:
+        """
+        Check if a collection (table) exists in the database.
+
+        Args:
+            table: Name of the table to check
+            schema: Schema name (default: 'public')
+
+        Returns:
+            bool: True if the collection exists, False otherwise
+        """
+        if not self._connected:
+            await self.connection()
+
+        async with self.session() as session:
+            # Plain (non-f) string: schema and table are bound parameters
+            query = text("""
+                SELECT EXISTS (
+                    SELECT 1 FROM information_schema.tables
+                    WHERE table_schema = :schema AND table_name = :table
+                )
+            """)
+            result = await session.execute(query, {"schema": schema, "table": table})
+            return bool(result.scalar())
+
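A typical guard built on `collection_exists`, sketched under the assumption of a connected `store` instance; the table name is illustrative.

```python
import asyncio

async def ensure_collection(store):
    # `store` is assumed to be a connected instance of this class.
    if not await store.collection_exists("documents", schema="public"):
        await store.create_collection("documents", dimension=768)

# asyncio.run(ensure_collection(store))  # uncomment once `store` exists
```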
+    async def delete_collection(
+        self,
+        table: str,
+        schema: str = 'public'
+    ) -> None:
+        """
+        Delete a collection (table) from the database.
+
+        Args:
+            table: Name of the table to delete
+            schema: Schema name (default: 'public')
+
+        Raises:
+            RuntimeError: If the collection does not exist
+        """
+        if not self._connected:
+            await self.connection()
+
+        if not await self.collection_exists(table, schema):
+            raise RuntimeError(
+                f"Collection {schema}.{table} does not exist"
+            )
+
+        async with self.session() as session:
+            # NOTE: identifiers cannot be bound as parameters, so the names
+            # are interpolated directly; callers must pass trusted values
+            query = text(
+                f"DROP TABLE IF EXISTS {schema}.{table} CASCADE"
+            )
+            await session.execute(query)
+            self.logger.info(
+                f"Collection {schema}.{table} deleted successfully"
+            )
+
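Because `delete_collection` raises when the table is missing, an idempotent drop pairs it with `collection_exists`. A sketch, again assuming a connected `store` instance:

```python
import asyncio

async def drop_if_present(store, table: str, schema: str = "public"):
    # Avoid the RuntimeError raised for a missing collection.
    if await store.collection_exists(table, schema):
        await store.delete_collection(table, schema)

# asyncio.run(drop_if_present(store, "old_documents"))  # `store` assumed
```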
+    async def create_collection(
+        self,
+        table: str,
+        schema: str = 'public',
+        dimension: int = 768,
+        index_type: str = "COSINE",
+        metric_type: str = 'L2',
+        id_column: Optional[str] = None,
+        **kwargs
+    ) -> None:
+        """
+        Create a new collection (table) in the database.
+
+        Args:
+            table: Name of the table to create
+            schema: Schema name (default: 'public')
+            dimension: Embedding dimension (default: 768)
+            index_type: Type of index to create (default: "COSINE")
+            metric_type: Distance metric type (default: 'L2')
+            id_column: Name of the ID column (default: 'id')
+            **kwargs: Forwarded to prepare_embedding_table, e.g. the
+                embedding, document and metadata column names
+
+        Raises:
+            RuntimeError: If collection creation fails
+        """
+        if not self._connected:
+            await self.connection()
+
+        # Construct full table name
+        full_table_name = f"{schema}.{table}" if schema != 'public' else table
+        self._metric_type: str = metric_type.upper()
+        self._index_type: str = index_type.upper()
+        try:
+            async with self.session() as session:
+                # Check if collection already exists
+                if await self.collection_exists(table, schema):
+                    self.logger.info(
+                        f"Collection {schema}.{table} already exists"
+                    )
+                else:
+                    id_column = id_column or self._id_column or 'id'
+                    # Create the collection:
+                    self.logger.info(f"Creating collection {schema}.{table}...")
+                    create_query = text(f"""
+                        CREATE TABLE {full_table_name} (
+                            {id_column} TEXT PRIMARY KEY
+                        )
+                    """)
+                    await session.execute(create_query)
+                    self.logger.info(
+                        f"Collection {schema}.{table} created successfully"
+                    )
+                    # Execute prepare:
+                    await self.prepare_embedding_table(
+                        table=table,
+                        schema=schema,
+                        conn=session,
+                        dimension=dimension,
+                        id_column=id_column,
+                        create_all_indexes=True,
+                        **kwargs
+                    )
+        except Exception as e:
+            self.logger.error(f"Error creating collection: {e}")
+            raise RuntimeError(
+                f"Failed to create collection: {e}"
+            ) from e
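Creation is two-step: a bare table with only the ID column, then `prepare_embedding_table` adds the embedding, content, and metadata columns plus indexes. A usage sketch with only the parameters visible in this signature, assuming a connected `store` instance:

```python
import asyncio

async def bootstrap(store):
    # `store` is assumed connected; values are illustrative defaults.
    await store.create_collection(
        table="documents",
        schema="public",
        dimension=768,
        index_type="COSINE",
        metric_type="L2",
        id_column="id",
    )

# asyncio.run(bootstrap(store))  # uncomment once `store` exists
```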