ag2 0.9.1__py3-none-any.whl → 0.9.1.post0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ag2 might be problematic.
- {ag2-0.9.1.dist-info → ag2-0.9.1.post0.dist-info}/METADATA +264 -73
- ag2-0.9.1.post0.dist-info/RECORD +392 -0
- {ag2-0.9.1.dist-info → ag2-0.9.1.post0.dist-info}/WHEEL +1 -2
- autogen/__init__.py +89 -0
- autogen/_website/__init__.py +3 -0
- autogen/_website/generate_api_references.py +427 -0
- autogen/_website/generate_mkdocs.py +1174 -0
- autogen/_website/notebook_processor.py +476 -0
- autogen/_website/process_notebooks.py +656 -0
- autogen/_website/utils.py +412 -0
- autogen/agentchat/__init__.py +44 -0
- autogen/agentchat/agent.py +182 -0
- autogen/agentchat/assistant_agent.py +85 -0
- autogen/agentchat/chat.py +309 -0
- autogen/agentchat/contrib/__init__.py +5 -0
- autogen/agentchat/contrib/agent_eval/README.md +7 -0
- autogen/agentchat/contrib/agent_eval/agent_eval.py +108 -0
- autogen/agentchat/contrib/agent_eval/criterion.py +43 -0
- autogen/agentchat/contrib/agent_eval/critic_agent.py +44 -0
- autogen/agentchat/contrib/agent_eval/quantifier_agent.py +39 -0
- autogen/agentchat/contrib/agent_eval/subcritic_agent.py +45 -0
- autogen/agentchat/contrib/agent_eval/task.py +42 -0
- autogen/agentchat/contrib/agent_optimizer.py +429 -0
- autogen/agentchat/contrib/capabilities/__init__.py +5 -0
- autogen/agentchat/contrib/capabilities/agent_capability.py +20 -0
- autogen/agentchat/contrib/capabilities/generate_images.py +301 -0
- autogen/agentchat/contrib/capabilities/teachability.py +393 -0
- autogen/agentchat/contrib/capabilities/text_compressors.py +66 -0
- autogen/agentchat/contrib/capabilities/tools_capability.py +22 -0
- autogen/agentchat/contrib/capabilities/transform_messages.py +93 -0
- autogen/agentchat/contrib/capabilities/transforms.py +566 -0
- autogen/agentchat/contrib/capabilities/transforms_util.py +122 -0
- autogen/agentchat/contrib/capabilities/vision_capability.py +214 -0
- autogen/agentchat/contrib/captainagent/__init__.py +9 -0
- autogen/agentchat/contrib/captainagent/agent_builder.py +790 -0
- autogen/agentchat/contrib/captainagent/captainagent.py +512 -0
- autogen/agentchat/contrib/captainagent/tool_retriever.py +335 -0
- autogen/agentchat/contrib/captainagent/tools/README.md +44 -0
- autogen/agentchat/contrib/captainagent/tools/__init__.py +5 -0
- autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_correlation.py +40 -0
- autogen/agentchat/contrib/captainagent/tools/data_analysis/calculate_skewness_and_kurtosis.py +28 -0
- autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_iqr.py +28 -0
- autogen/agentchat/contrib/captainagent/tools/data_analysis/detect_outlier_zscore.py +28 -0
- autogen/agentchat/contrib/captainagent/tools/data_analysis/explore_csv.py +21 -0
- autogen/agentchat/contrib/captainagent/tools/data_analysis/shapiro_wilk_test.py +30 -0
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_download.py +27 -0
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/arxiv_search.py +53 -0
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_image.py +53 -0
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/extract_pdf_text.py +38 -0
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_wikipedia_text.py +21 -0
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/get_youtube_caption.py +34 -0
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/image_qa.py +60 -0
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/optical_character_recognition.py +61 -0
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/perform_web_search.py +47 -0
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/scrape_wikipedia_tables.py +33 -0
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/transcribe_audio_file.py +21 -0
- autogen/agentchat/contrib/captainagent/tools/information_retrieval/youtube_download.py +35 -0
- autogen/agentchat/contrib/captainagent/tools/math/calculate_circle_area_from_diameter.py +21 -0
- autogen/agentchat/contrib/captainagent/tools/math/calculate_day_of_the_week.py +18 -0
- autogen/agentchat/contrib/captainagent/tools/math/calculate_fraction_sum.py +28 -0
- autogen/agentchat/contrib/captainagent/tools/math/calculate_matrix_power.py +31 -0
- autogen/agentchat/contrib/captainagent/tools/math/calculate_reflected_point.py +16 -0
- autogen/agentchat/contrib/captainagent/tools/math/complex_numbers_product.py +25 -0
- autogen/agentchat/contrib/captainagent/tools/math/compute_currency_conversion.py +23 -0
- autogen/agentchat/contrib/captainagent/tools/math/count_distinct_permutations.py +27 -0
- autogen/agentchat/contrib/captainagent/tools/math/evaluate_expression.py +28 -0
- autogen/agentchat/contrib/captainagent/tools/math/find_continuity_point.py +34 -0
- autogen/agentchat/contrib/captainagent/tools/math/fraction_to_mixed_numbers.py +39 -0
- autogen/agentchat/contrib/captainagent/tools/math/modular_inverse_sum.py +23 -0
- autogen/agentchat/contrib/captainagent/tools/math/simplify_mixed_numbers.py +36 -0
- autogen/agentchat/contrib/captainagent/tools/math/sum_of_digit_factorials.py +15 -0
- autogen/agentchat/contrib/captainagent/tools/math/sum_of_primes_below.py +15 -0
- autogen/agentchat/contrib/captainagent/tools/requirements.txt +10 -0
- autogen/agentchat/contrib/captainagent/tools/tool_description.tsv +34 -0
- autogen/agentchat/contrib/gpt_assistant_agent.py +526 -0
- autogen/agentchat/contrib/graph_rag/__init__.py +9 -0
- autogen/agentchat/contrib/graph_rag/document.py +29 -0
- autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py +170 -0
- autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py +103 -0
- autogen/agentchat/contrib/graph_rag/graph_query_engine.py +53 -0
- autogen/agentchat/contrib/graph_rag/graph_rag_capability.py +63 -0
- autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py +268 -0
- autogen/agentchat/contrib/graph_rag/neo4j_graph_rag_capability.py +83 -0
- autogen/agentchat/contrib/graph_rag/neo4j_native_graph_query_engine.py +210 -0
- autogen/agentchat/contrib/graph_rag/neo4j_native_graph_rag_capability.py +93 -0
- autogen/agentchat/contrib/img_utils.py +397 -0
- autogen/agentchat/contrib/llamaindex_conversable_agent.py +117 -0
- autogen/agentchat/contrib/llava_agent.py +187 -0
- autogen/agentchat/contrib/math_user_proxy_agent.py +464 -0
- autogen/agentchat/contrib/multimodal_conversable_agent.py +125 -0
- autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +324 -0
- autogen/agentchat/contrib/rag/__init__.py +10 -0
- autogen/agentchat/contrib/rag/chromadb_query_engine.py +272 -0
- autogen/agentchat/contrib/rag/llamaindex_query_engine.py +198 -0
- autogen/agentchat/contrib/rag/mongodb_query_engine.py +329 -0
- autogen/agentchat/contrib/rag/query_engine.py +74 -0
- autogen/agentchat/contrib/retrieve_assistant_agent.py +56 -0
- autogen/agentchat/contrib/retrieve_user_proxy_agent.py +703 -0
- autogen/agentchat/contrib/society_of_mind_agent.py +199 -0
- autogen/agentchat/contrib/swarm_agent.py +1425 -0
- autogen/agentchat/contrib/text_analyzer_agent.py +79 -0
- autogen/agentchat/contrib/vectordb/__init__.py +5 -0
- autogen/agentchat/contrib/vectordb/base.py +232 -0
- autogen/agentchat/contrib/vectordb/chromadb.py +315 -0
- autogen/agentchat/contrib/vectordb/couchbase.py +407 -0
- autogen/agentchat/contrib/vectordb/mongodb.py +550 -0
- autogen/agentchat/contrib/vectordb/pgvectordb.py +928 -0
- autogen/agentchat/contrib/vectordb/qdrant.py +320 -0
- autogen/agentchat/contrib/vectordb/utils.py +126 -0
- autogen/agentchat/contrib/web_surfer.py +303 -0
- autogen/agentchat/conversable_agent.py +4020 -0
- autogen/agentchat/group/__init__.py +64 -0
- autogen/agentchat/group/available_condition.py +91 -0
- autogen/agentchat/group/context_condition.py +77 -0
- autogen/agentchat/group/context_expression.py +238 -0
- autogen/agentchat/group/context_str.py +41 -0
- autogen/agentchat/group/context_variables.py +192 -0
- autogen/agentchat/group/group_tool_executor.py +202 -0
- autogen/agentchat/group/group_utils.py +591 -0
- autogen/agentchat/group/handoffs.py +244 -0
- autogen/agentchat/group/llm_condition.py +93 -0
- autogen/agentchat/group/multi_agent_chat.py +237 -0
- autogen/agentchat/group/on_condition.py +58 -0
- autogen/agentchat/group/on_context_condition.py +54 -0
- autogen/agentchat/group/patterns/__init__.py +18 -0
- autogen/agentchat/group/patterns/auto.py +159 -0
- autogen/agentchat/group/patterns/manual.py +176 -0
- autogen/agentchat/group/patterns/pattern.py +288 -0
- autogen/agentchat/group/patterns/random.py +106 -0
- autogen/agentchat/group/patterns/round_robin.py +117 -0
- autogen/agentchat/group/reply_result.py +26 -0
- autogen/agentchat/group/speaker_selection_result.py +41 -0
- autogen/agentchat/group/targets/__init__.py +4 -0
- autogen/agentchat/group/targets/group_chat_target.py +132 -0
- autogen/agentchat/group/targets/group_manager_target.py +151 -0
- autogen/agentchat/group/targets/transition_target.py +413 -0
- autogen/agentchat/group/targets/transition_utils.py +6 -0
- autogen/agentchat/groupchat.py +1694 -0
- autogen/agentchat/realtime/__init__.py +3 -0
- autogen/agentchat/realtime/experimental/__init__.py +20 -0
- autogen/agentchat/realtime/experimental/audio_adapters/__init__.py +8 -0
- autogen/agentchat/realtime/experimental/audio_adapters/twilio_audio_adapter.py +148 -0
- autogen/agentchat/realtime/experimental/audio_adapters/websocket_audio_adapter.py +139 -0
- autogen/agentchat/realtime/experimental/audio_observer.py +42 -0
- autogen/agentchat/realtime/experimental/clients/__init__.py +15 -0
- autogen/agentchat/realtime/experimental/clients/gemini/__init__.py +7 -0
- autogen/agentchat/realtime/experimental/clients/gemini/client.py +274 -0
- autogen/agentchat/realtime/experimental/clients/oai/__init__.py +8 -0
- autogen/agentchat/realtime/experimental/clients/oai/base_client.py +220 -0
- autogen/agentchat/realtime/experimental/clients/oai/rtc_client.py +243 -0
- autogen/agentchat/realtime/experimental/clients/oai/utils.py +48 -0
- autogen/agentchat/realtime/experimental/clients/realtime_client.py +190 -0
- autogen/agentchat/realtime/experimental/function_observer.py +85 -0
- autogen/agentchat/realtime/experimental/realtime_agent.py +158 -0
- autogen/agentchat/realtime/experimental/realtime_events.py +42 -0
- autogen/agentchat/realtime/experimental/realtime_observer.py +100 -0
- autogen/agentchat/realtime/experimental/realtime_swarm.py +475 -0
- autogen/agentchat/realtime/experimental/websockets.py +21 -0
- autogen/agentchat/realtime_agent/__init__.py +21 -0
- autogen/agentchat/user_proxy_agent.py +111 -0
- autogen/agentchat/utils.py +206 -0
- autogen/agents/__init__.py +3 -0
- autogen/agents/contrib/__init__.py +10 -0
- autogen/agents/contrib/time/__init__.py +8 -0
- autogen/agents/contrib/time/time_reply_agent.py +73 -0
- autogen/agents/contrib/time/time_tool_agent.py +51 -0
- autogen/agents/experimental/__init__.py +27 -0
- autogen/agents/experimental/deep_research/__init__.py +7 -0
- autogen/agents/experimental/deep_research/deep_research.py +52 -0
- autogen/agents/experimental/discord/__init__.py +7 -0
- autogen/agents/experimental/discord/discord.py +66 -0
- autogen/agents/experimental/document_agent/__init__.py +19 -0
- autogen/agents/experimental/document_agent/chroma_query_engine.py +316 -0
- autogen/agents/experimental/document_agent/docling_doc_ingest_agent.py +118 -0
- autogen/agents/experimental/document_agent/document_agent.py +461 -0
- autogen/agents/experimental/document_agent/document_conditions.py +50 -0
- autogen/agents/experimental/document_agent/document_utils.py +380 -0
- autogen/agents/experimental/document_agent/inmemory_query_engine.py +220 -0
- autogen/agents/experimental/document_agent/parser_utils.py +130 -0
- autogen/agents/experimental/document_agent/url_utils.py +426 -0
- autogen/agents/experimental/reasoning/__init__.py +7 -0
- autogen/agents/experimental/reasoning/reasoning_agent.py +1178 -0
- autogen/agents/experimental/slack/__init__.py +7 -0
- autogen/agents/experimental/slack/slack.py +73 -0
- autogen/agents/experimental/telegram/__init__.py +7 -0
- autogen/agents/experimental/telegram/telegram.py +77 -0
- autogen/agents/experimental/websurfer/__init__.py +7 -0
- autogen/agents/experimental/websurfer/websurfer.py +62 -0
- autogen/agents/experimental/wikipedia/__init__.py +7 -0
- autogen/agents/experimental/wikipedia/wikipedia.py +90 -0
- autogen/browser_utils.py +309 -0
- autogen/cache/__init__.py +10 -0
- autogen/cache/abstract_cache_base.py +75 -0
- autogen/cache/cache.py +203 -0
- autogen/cache/cache_factory.py +88 -0
- autogen/cache/cosmos_db_cache.py +144 -0
- autogen/cache/disk_cache.py +102 -0
- autogen/cache/in_memory_cache.py +58 -0
- autogen/cache/redis_cache.py +123 -0
- autogen/code_utils.py +596 -0
- autogen/coding/__init__.py +22 -0
- autogen/coding/base.py +119 -0
- autogen/coding/docker_commandline_code_executor.py +268 -0
- autogen/coding/factory.py +47 -0
- autogen/coding/func_with_reqs.py +202 -0
- autogen/coding/jupyter/__init__.py +23 -0
- autogen/coding/jupyter/base.py +36 -0
- autogen/coding/jupyter/docker_jupyter_server.py +167 -0
- autogen/coding/jupyter/embedded_ipython_code_executor.py +182 -0
- autogen/coding/jupyter/import_utils.py +82 -0
- autogen/coding/jupyter/jupyter_client.py +231 -0
- autogen/coding/jupyter/jupyter_code_executor.py +160 -0
- autogen/coding/jupyter/local_jupyter_server.py +172 -0
- autogen/coding/local_commandline_code_executor.py +405 -0
- autogen/coding/markdown_code_extractor.py +45 -0
- autogen/coding/utils.py +56 -0
- autogen/doc_utils.py +34 -0
- autogen/events/__init__.py +7 -0
- autogen/events/agent_events.py +1010 -0
- autogen/events/base_event.py +99 -0
- autogen/events/client_events.py +167 -0
- autogen/events/helpers.py +36 -0
- autogen/events/print_event.py +46 -0
- autogen/exception_utils.py +73 -0
- autogen/extensions/__init__.py +5 -0
- autogen/fast_depends/__init__.py +16 -0
- autogen/fast_depends/_compat.py +80 -0
- autogen/fast_depends/core/__init__.py +14 -0
- autogen/fast_depends/core/build.py +225 -0
- autogen/fast_depends/core/model.py +576 -0
- autogen/fast_depends/dependencies/__init__.py +15 -0
- autogen/fast_depends/dependencies/model.py +29 -0
- autogen/fast_depends/dependencies/provider.py +39 -0
- autogen/fast_depends/library/__init__.py +10 -0
- autogen/fast_depends/library/model.py +46 -0
- autogen/fast_depends/py.typed +6 -0
- autogen/fast_depends/schema.py +66 -0
- autogen/fast_depends/use.py +280 -0
- autogen/fast_depends/utils.py +187 -0
- autogen/formatting_utils.py +83 -0
- autogen/function_utils.py +13 -0
- autogen/graph_utils.py +178 -0
- autogen/import_utils.py +526 -0
- autogen/interop/__init__.py +22 -0
- autogen/interop/crewai/__init__.py +7 -0
- autogen/interop/crewai/crewai.py +88 -0
- autogen/interop/interoperability.py +71 -0
- autogen/interop/interoperable.py +46 -0
- autogen/interop/langchain/__init__.py +8 -0
- autogen/interop/langchain/langchain_chat_model_factory.py +155 -0
- autogen/interop/langchain/langchain_tool.py +82 -0
- autogen/interop/litellm/__init__.py +7 -0
- autogen/interop/litellm/litellm_config_factory.py +113 -0
- autogen/interop/pydantic_ai/__init__.py +7 -0
- autogen/interop/pydantic_ai/pydantic_ai.py +168 -0
- autogen/interop/registry.py +69 -0
- autogen/io/__init__.py +15 -0
- autogen/io/base.py +151 -0
- autogen/io/console.py +56 -0
- autogen/io/processors/__init__.py +12 -0
- autogen/io/processors/base.py +21 -0
- autogen/io/processors/console_event_processor.py +56 -0
- autogen/io/run_response.py +293 -0
- autogen/io/thread_io_stream.py +63 -0
- autogen/io/websockets.py +213 -0
- autogen/json_utils.py +43 -0
- autogen/llm_config.py +379 -0
- autogen/logger/__init__.py +11 -0
- autogen/logger/base_logger.py +128 -0
- autogen/logger/file_logger.py +261 -0
- autogen/logger/logger_factory.py +42 -0
- autogen/logger/logger_utils.py +57 -0
- autogen/logger/sqlite_logger.py +523 -0
- autogen/math_utils.py +339 -0
- autogen/mcp/__init__.py +7 -0
- autogen/mcp/mcp_client.py +208 -0
- autogen/messages/__init__.py +7 -0
- autogen/messages/agent_messages.py +948 -0
- autogen/messages/base_message.py +107 -0
- autogen/messages/client_messages.py +171 -0
- autogen/messages/print_message.py +49 -0
- autogen/oai/__init__.py +53 -0
- autogen/oai/anthropic.py +714 -0
- autogen/oai/bedrock.py +628 -0
- autogen/oai/cerebras.py +299 -0
- autogen/oai/client.py +1435 -0
- autogen/oai/client_utils.py +169 -0
- autogen/oai/cohere.py +479 -0
- autogen/oai/gemini.py +990 -0
- autogen/oai/gemini_types.py +129 -0
- autogen/oai/groq.py +305 -0
- autogen/oai/mistral.py +303 -0
- autogen/oai/oai_models/__init__.py +11 -0
- autogen/oai/oai_models/_models.py +16 -0
- autogen/oai/oai_models/chat_completion.py +87 -0
- autogen/oai/oai_models/chat_completion_audio.py +32 -0
- autogen/oai/oai_models/chat_completion_message.py +86 -0
- autogen/oai/oai_models/chat_completion_message_tool_call.py +37 -0
- autogen/oai/oai_models/chat_completion_token_logprob.py +63 -0
- autogen/oai/oai_models/completion_usage.py +60 -0
- autogen/oai/ollama.py +643 -0
- autogen/oai/openai_utils.py +881 -0
- autogen/oai/together.py +370 -0
- autogen/retrieve_utils.py +491 -0
- autogen/runtime_logging.py +160 -0
- autogen/token_count_utils.py +267 -0
- autogen/tools/__init__.py +20 -0
- autogen/tools/contrib/__init__.py +9 -0
- autogen/tools/contrib/time/__init__.py +7 -0
- autogen/tools/contrib/time/time.py +41 -0
- autogen/tools/dependency_injection.py +254 -0
- autogen/tools/experimental/__init__.py +43 -0
- autogen/tools/experimental/browser_use/__init__.py +7 -0
- autogen/tools/experimental/browser_use/browser_use.py +161 -0
- autogen/tools/experimental/crawl4ai/__init__.py +7 -0
- autogen/tools/experimental/crawl4ai/crawl4ai.py +153 -0
- autogen/tools/experimental/deep_research/__init__.py +7 -0
- autogen/tools/experimental/deep_research/deep_research.py +328 -0
- autogen/tools/experimental/duckduckgo/__init__.py +7 -0
- autogen/tools/experimental/duckduckgo/duckduckgo_search.py +109 -0
- autogen/tools/experimental/google/__init__.py +14 -0
- autogen/tools/experimental/google/authentication/__init__.py +11 -0
- autogen/tools/experimental/google/authentication/credentials_hosted_provider.py +43 -0
- autogen/tools/experimental/google/authentication/credentials_local_provider.py +91 -0
- autogen/tools/experimental/google/authentication/credentials_provider.py +35 -0
- autogen/tools/experimental/google/drive/__init__.py +9 -0
- autogen/tools/experimental/google/drive/drive_functions.py +124 -0
- autogen/tools/experimental/google/drive/toolkit.py +88 -0
- autogen/tools/experimental/google/model.py +17 -0
- autogen/tools/experimental/google/toolkit_protocol.py +19 -0
- autogen/tools/experimental/google_search/__init__.py +8 -0
- autogen/tools/experimental/google_search/google_search.py +93 -0
- autogen/tools/experimental/google_search/youtube_search.py +181 -0
- autogen/tools/experimental/messageplatform/__init__.py +17 -0
- autogen/tools/experimental/messageplatform/discord/__init__.py +7 -0
- autogen/tools/experimental/messageplatform/discord/discord.py +288 -0
- autogen/tools/experimental/messageplatform/slack/__init__.py +7 -0
- autogen/tools/experimental/messageplatform/slack/slack.py +391 -0
- autogen/tools/experimental/messageplatform/telegram/__init__.py +7 -0
- autogen/tools/experimental/messageplatform/telegram/telegram.py +275 -0
- autogen/tools/experimental/perplexity/__init__.py +7 -0
- autogen/tools/experimental/perplexity/perplexity_search.py +260 -0
- autogen/tools/experimental/tavily/__init__.py +7 -0
- autogen/tools/experimental/tavily/tavily_search.py +183 -0
- autogen/tools/experimental/web_search_preview/__init__.py +7 -0
- autogen/tools/experimental/web_search_preview/web_search_preview.py +114 -0
- autogen/tools/experimental/wikipedia/__init__.py +7 -0
- autogen/tools/experimental/wikipedia/wikipedia.py +287 -0
- autogen/tools/function_utils.py +411 -0
- autogen/tools/tool.py +187 -0
- autogen/tools/toolkit.py +86 -0
- autogen/types.py +29 -0
- autogen/version.py +7 -0
- ag2-0.9.1.dist-info/RECORD +0 -6
- ag2-0.9.1.dist-info/top_level.txt +0 -1
- {ag2-0.9.1.dist-info → ag2-0.9.1.post0.dist-info/licenses}/LICENSE +0 -0
- {ag2-0.9.1.dist-info → ag2-0.9.1.post0.dist-info/licenses}/NOTICE.md +0 -0

autogen/agentchat/contrib/vectordb/pgvectordb.py
@@ -0,0 +1,928 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import os
import re
import urllib.parse
from typing import Any, Callable, Optional, Union

from ....import_utils import optional_import_block, require_optional_import
from .base import Document, ItemID, QueryResults, VectorDB
from .utils import get_logger

with optional_import_block():
    import numpy as np
    import pgvector  # noqa: F401
    import psycopg
    from pgvector.psycopg import register_vector
    from sentence_transformers import SentenceTransformer

PGVECTOR_MAX_BATCH_SIZE = os.environ.get("PGVECTOR_MAX_BATCH_SIZE", 40000)
logger = get_logger(__name__)


@require_optional_import(["psycopg", "sentence_transformers", "numpy"], "retrievechat-pgvector")
class Collection:
    """A Collection object for PGVector.

    Attributes:
        client: The PGVector client.
        collection_name (str): The name of the collection. Default is "documents".
        embedding_function (Callable): The embedding function used to generate the vector representation.
            Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
            Models can be chosen from:
            https://huggingface.co/models?library=sentence-transformers
        metadata (Optional[dict[str, Any]]): The metadata of the collection.
        get_or_create (Optional): The flag indicating whether to get or create the collection.
    """

    def __init__(
        self,
        client: Optional[Any] = None,
        collection_name: str = "ag2-docs",
        embedding_function: Optional[Callable[..., Any]] = None,
        metadata: Optional[Any] = None,
        get_or_create: Optional[Any] = None,
    ):
        """Initialize the Collection object.

        Args:
            client: The PostgreSQL client.
            collection_name: The name of the collection. Default is "documents".
            embedding_function: The embedding function used to generate the vector representation.
            metadata: The metadata of the collection.
            get_or_create: The flag indicating whether to get or create the collection.

        Returns:
            None
        """
        self.client = client
        self.name = self.set_collection_name(collection_name)
        self.require_embeddings_or_documents = False
        self.ids = []
        if embedding_function:
            self.embedding_function = embedding_function
        else:
            self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode
        self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
        self.documents = ""
        self.get_or_create = get_or_create
        # This will get the model dimension size by computing the embeddings dimensions
        sentences = [
            "The weather is lovely today in paradise.",
        ]
        embeddings = self.embedding_function(sentences)
        self.dimension = len(embeddings[0])

    def set_collection_name(self, collection_name) -> str:
        name = re.sub("-", "_", collection_name)
        self.name = name
        return self.name

    def add(
        self,
        ids: list[ItemID],
        documents: Optional[list[Document]],
        embeddings: Optional[list[Any]] = None,
        metadatas: Optional[list[Any]] = None,
    ) -> None:
        """Add documents to the collection.

        Args:
            ids (List[ItemID]): A list of document IDs.
            embeddings (List): A list of document embeddings. Optional
            metadatas (List): A list of document metadatas. Optional
            documents (List): A list of documents.

        Returns:
            None
        """
        cursor = self.client.cursor()
        sql_values = []
        if embeddings is not None and metadatas is not None:
            for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
                metadata = re.sub("'", '"', str(metadata))
                sql_values.append((doc_id, embedding, metadata, document))
            sql_string = f"INSERT INTO {self.name} (id, embedding, metadatas, documents)\nVALUES (%s, %s, %s, %s);\n"
        elif embeddings is not None:
            for doc_id, embedding, document in zip(ids, embeddings, documents):
                sql_values.append((doc_id, embedding, document))
            sql_string = f"INSERT INTO {self.name} (id, embedding, documents) VALUES (%s, %s, %s);\n"
        elif metadatas is not None:
            for doc_id, metadata, document in zip(ids, metadatas, documents):
                metadata = re.sub("'", '"', str(metadata))
                embedding = self.embedding_function(document)
                sql_values.append((doc_id, metadata, embedding, document))
            sql_string = f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\nVALUES (%s, %s, %s, %s);\n"
        else:
            for doc_id, document in zip(ids, documents):
                embedding = self.embedding_function(document)
                sql_values.append((doc_id, document, embedding))
            sql_string = f"INSERT INTO {self.name} (id, documents, embedding)\nVALUES (%s, %s, %s);\n"
        logger.debug(f"Add SQL String:\n{sql_string}\n{sql_values}")
        cursor.executemany(sql_string, sql_values)
        cursor.close()

    def upsert(
        self,
        ids: list[ItemID],
        documents: list[Document],
        embeddings: Optional[list[Any]] = None,
        metadatas: Optional[list[Any]] = None,
    ) -> None:
        """Upsert documents into the collection.

        Args:
            ids (List[ItemID]): A list of document IDs.
            documents (List): A list of documents.
            embeddings (List): A list of document embeddings.
            metadatas (List): A list of document metadatas.

        Returns:
            None
        """
        cursor = self.client.cursor()
        sql_values = []
        if embeddings is not None and metadatas is not None:
            for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
                metadata = re.sub("'", '"', str(metadata))
                sql_values.append((doc_id, embedding, metadata, document, embedding, metadata, document))
            sql_string = (
                f"INSERT INTO {self.name} (id, embedding, metadatas, documents)\n"
                f"VALUES (%s, %s, %s, %s)\n"
                f"ON CONFLICT (id)\n"
                f"DO UPDATE SET embedding = %s,\n"
                f"metadatas = %s, documents = %s;\n"
            )
        elif embeddings is not None:
            for doc_id, embedding, document in zip(ids, embeddings, documents):
                sql_values.append((doc_id, embedding, document, embedding, document))
            sql_string = (
                f"INSERT INTO {self.name} (id, embedding, documents) "
                f"VALUES (%s, %s, %s) ON CONFLICT (id)\n"
                f"DO UPDATE SET embedding = %s, documents = %s;\n"
            )
        elif metadatas is not None:
            for doc_id, metadata, document in zip(ids, metadatas, documents):
                metadata = re.sub("'", '"', str(metadata))
                embedding = self.embedding_function(document)
                sql_values.append((doc_id, metadata, embedding, document, metadata, document, embedding))
            sql_string = (
                f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n"
                f"VALUES (%s, %s, %s, %s)\n"
                f"ON CONFLICT (id)\n"
                f"DO UPDATE SET metadatas = %s, documents = %s, embedding = %s;\n"
            )
        else:
            for doc_id, document in zip(ids, documents):
                embedding = self.embedding_function(document)
                sql_values.append((doc_id, document, embedding, document))
            sql_string = (
                f"INSERT INTO {self.name} (id, documents, embedding)\n"
                f"VALUES (%s, %s, %s)\n"
                f"ON CONFLICT (id)\n"
                f"DO UPDATE SET documents = %s;\n"
            )
        logger.debug(f"Upsert SQL String:\n{sql_string}\n{sql_values}")
        cursor.executemany(sql_string, sql_values)
        cursor.close()

    def count(self) -> int:
        """Get the total number of documents in the collection.

        Returns:
            int: The total number of documents.
        """
        cursor = self.client.cursor()
        query = f"SELECT COUNT(*) FROM {self.name}"
        cursor.execute(query)
        total = cursor.fetchone()[0]
        cursor.close()
        try:
            total = int(total)
        except (TypeError, ValueError):
            total = None
        return total

    def table_exists(self, table_name: str) -> bool:
        """Check if a table exists in the PostgreSQL database.

        Args:
            table_name (str): The name of the table to check.

        Returns:
            bool: True if the table exists, False otherwise.
        """
        cursor = self.client.cursor()
        cursor.execute(
            """
            SELECT EXISTS (
                SELECT 1
                FROM information_schema.tables
                WHERE table_name = %s
            )
            """,
            (table_name,),
        )
        exists = cursor.fetchone()[0]
        return exists

    def get(
        self,
        ids: Optional[str] = None,
        include: Optional[str] = None,
        where: Optional[str] = None,
        limit: Optional[Union[int, str]] = None,
        offset: Optional[Union[int, str]] = None,
    ) -> list[Document]:
        """Retrieve documents from the collection.

        Args:
            ids (Optional[List]): A list of document IDs.
            include (Optional): The fields to include.
            where (Optional): Additional filtering criteria.
            limit (Optional): The maximum number of documents to retrieve.
            offset (Optional): The offset for pagination.

        Returns:
            List: The retrieved documents.
        """
        cursor = self.client.cursor()

        # Initialize variables for query components
        select_clause = "SELECT id, metadatas, documents, embedding"
        from_clause = f"FROM {self.name}"
        where_clause = ""
        limit_clause = ""
        offset_clause = ""

        # Handle include clause
        if include:
            select_clause = f"SELECT id, {', '.join(include)}, embedding"

        # Handle where clause
        if ids:
            where_clause = f"WHERE id IN ({', '.join(['%s' for _ in ids])})"
        elif where:
            where_clause = f"WHERE {where}"

        # Handle limit and offset clauses
        if limit:
            limit_clause = "LIMIT %s"
        if offset:
            offset_clause = "OFFSET %s"

        # Construct the full query
        query = f"{select_clause} {from_clause} {where_clause} {limit_clause} {offset_clause}"
        retrieved_documents = []
        try:
            # Execute the query with the appropriate values
            if ids is not None:
                cursor.execute(query, ids)
            else:
                query_params = []
                if limit:
                    query_params.append(limit)
                if offset:
                    query_params.append(offset)
                cursor.execute(query, query_params)

            retrieval = cursor.fetchall()
            for retrieved_document in retrieval:
                retrieved_documents.append(
                    Document(
                        id=retrieved_document[0].strip(),
                        metadata=retrieved_document[1],
                        content=retrieved_document[2],
                        embedding=retrieved_document[3],
                    )
                )
        except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn) as e:
            logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead. Error: {e}")
            self.create_collection(collection_name=self.name, dimension=self.dimension)
            logger.info(f"Created table {self.name}")

        cursor.close()
        return retrieved_documents

    def update(self, ids: list[str], embeddings: list[Any], metadatas: list[Any], documents: list[Document]) -> None:
        """Update documents in the collection.

        Args:
            ids (List): A list of document IDs.
            embeddings (List): A list of document embeddings.
            metadatas (List): A list of document metadatas.
            documents (List): A list of documents.

        Returns:
            None
        """
        cursor = self.client.cursor()
        sql_values = []
        for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
            sql_values.append((doc_id, embedding, metadata, document, doc_id, embedding, metadata, document))
        sql_string = (
            f"INSERT INTO {self.name} (id, embedding, metadata, document) "
            f"VALUES (%s, %s, %s, %s) "
            f"ON CONFLICT (id) "
            f"DO UPDATE SET id = %s, embedding = %s, "
            f"metadata = %s, document = %s;\n"
        )
        logger.debug(f"Upsert SQL String:\n{sql_string}\n")
        cursor.executemany(sql_string, sql_values)
        cursor.close()

    @staticmethod
    def euclidean_distance(arr1: list[float], arr2: list[float]) -> float:
        """Calculate the Euclidean distance between two vectors.

        Parameters:
        - arr1 (List[float]): The first vector.
        - arr2 (List[float]): The second vector.

        Returns:
        - float: The Euclidean distance between arr1 and arr2.
        """
        dist = np.linalg.norm(arr1 - arr2)
        return dist

    @staticmethod
    def cosine_distance(arr1: list[float], arr2: list[float]) -> float:
        """Calculate the cosine distance between two vectors.

        Parameters:
        - arr1 (List[float]): The first vector.
        - arr2 (List[float]): The second vector.

        Returns:
        - float: The cosine distance between arr1 and arr2.
        """
        dist = np.dot(arr1, arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))
        return dist

    @staticmethod
    def inner_product_distance(arr1: list[float], arr2: list[float]) -> float:
        """Calculate the Euclidean distance between two vectors.

        Parameters:
        - arr1 (List[float]): The first vector.
        - arr2 (List[float]): The second vector.

        Returns:
        - float: The Euclidean distance between arr1 and arr2.
        """
        dist = np.linalg.norm(arr1 - arr2)
        return dist

    def query(
        self,
        query_texts: list[str],
        collection_name: Optional[str] = None,
        n_results: Optional[int] = 10,
        distance_type: Optional[str] = "euclidean",
        distance_threshold: Optional[float] = -1,
        include_embedding: Optional[bool] = False,
    ) -> QueryResults:
        """Query documents in the collection.

        Args:
            query_texts (List[str]): A list of query texts.
            collection_name (Optional[str]): The name of the collection.
            n_results (int): The maximum number of results to return.
            distance_type (Optional[str]): Distance search type - euclidean or cosine
            distance_threshold (Optional[float]): Distance threshold to limit searches
            include_embedding (Optional[bool]): Include embedding values in QueryResults
        Returns:
            QueryResults: The query results.
        """
        if collection_name:
            self.name = collection_name

        clause = "ORDER BY"
        if distance_threshold == -1:
            distance_threshold = ""
            clause = "ORDER BY"
        elif distance_threshold > 0:
            distance_threshold = f"< {distance_threshold}"
            clause = "WHERE"

        cursor = self.client.cursor()
        results = []
        for query_text in query_texts:
            vector = self.embedding_function(query_text, convert_to_tensor=False).tolist()
            if distance_type.lower() == "cosine":
                index_function = "<=>"
            elif distance_type.lower() == "euclidean":
                index_function = "<->"
            elif distance_type.lower() == "inner-product":
                index_function = "<#>"
            else:
                index_function = "<->"
            query = (
                f"SELECT id, documents, embedding, metadatas "
                f"FROM {self.name} "
                f"{clause} embedding {index_function} '{vector!s}' {distance_threshold} "
                f"LIMIT {n_results}"
            )
            cursor.execute(query)
            result = []
            for row in cursor.fetchall():
                fetched_document = Document(id=row[0].strip(), content=row[1], embedding=row[2], metadata=row[3])
                fetched_document_array = self.convert_string_to_array(array_string=fetched_document.get("embedding"))
                if distance_type.lower() == "cosine":
                    distance = self.cosine_distance(fetched_document_array, vector)
                elif distance_type.lower() == "euclidean":
                    distance = self.euclidean_distance(fetched_document_array, vector)
                elif distance_type.lower() == "inner-product":
                    distance = self.inner_product_distance(fetched_document_array, vector)
                else:
                    distance = self.euclidean_distance(fetched_document_array, vector)
                if not include_embedding:
                    fetched_document = Document(id=row[0].strip(), content=row[1], metadata=row[3])
                result.append((fetched_document, distance))
            results.append(result)
        cursor.close()
        logger.debug(f"Query Results: {results}")
        return results

    @staticmethod
    def convert_string_to_array(array_string: str) -> list[float]:
        """Convert a string representation of an array to a list of floats.

        Parameters:
        - array_string (str): The string representation of the array.

        Returns:
        - list: A list of floats parsed from the input string. If the input is
            not a string, it returns the input itself.
        """
        if not isinstance(array_string, str):
            return array_string
        array_string = array_string.strip("[]")
        array = [float(num) for num in array_string.split()]
        return array

    def modify(self, metadata, collection_name: Optional[str] = None) -> None:
        """Modify metadata for the collection.

        Args:
            collection_name: The name of the collection.
            metadata: The new metadata.

        Returns:
            None
        """
        if collection_name:
            self.name = collection_name
        cursor = self.client.cursor()
        cursor.execute("UPDATE collectionsSET metadata = '%s'WHERE collection_name = '%s';", (metadata, self.name))
        cursor.close()

    def delete(self, ids: list[ItemID], collection_name: Optional[str] = None) -> None:
        """Delete documents from the collection.

        Args:
            ids (List[ItemID]): A list of document IDs to delete.
            collection_name (str): The name of the collection to delete.

        Returns:
            None
        """
        if collection_name:
            self.name = collection_name
        cursor = self.client.cursor()
        id_placeholders = ", ".join(["%s" for _ in ids])
        cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({id_placeholders});", ids)
        cursor.close()

    def delete_collection(self, collection_name: Optional[str] = None) -> None:
        """Delete the entire collection.

        Args:
            collection_name (Optional[str]): The name of the collection to delete.

        Returns:
            None
        """
        if collection_name:
            self.name = collection_name
        cursor = self.client.cursor()
        cursor.execute(f"DROP TABLE IF EXISTS {self.name}")
        cursor.close()

    def create_collection(
        self, collection_name: Optional[str] = None, dimension: Optional[Union[str, int]] = None
    ) -> None:
        """Create a new collection.

        Args:
            collection_name (Optional[str]): The name of the new collection.
            dimension (Optional[Union[str, int]]): The dimension size of the sentence embedding model

        Returns:
            None
        """
        if collection_name:
            self.name = collection_name

        if dimension:
            self.dimension = dimension
        elif self.dimension is None:
            self.dimension = 384

        cursor = self.client.cursor()
        cursor.execute(
            f"CREATE TABLE {self.name} ("
            f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector({self.dimension}));"
            f"CREATE INDEX "
            f"ON {self.name} USING hnsw (embedding vector_l2_ops) WITH (m = {self.metadata['hnsw:M']}, "
            f"ef_construction = {self.metadata['hnsw:construction_ef']});"
            f"CREATE INDEX "
            f"ON {self.name} USING hnsw (embedding vector_cosine_ops) WITH (m = {self.metadata['hnsw:M']}, "
            f"ef_construction = {self.metadata['hnsw:construction_ef']});"
            f"CREATE INDEX "
            f"ON {self.name} USING hnsw (embedding vector_ip_ops) WITH (m = {self.metadata['hnsw:M']}, "
            f"ef_construction = {self.metadata['hnsw:construction_ef']});"
        )
        cursor.close()


@require_optional_import(["pgvector", "psycopg", "sentence_transformers"], "retrievechat-pgvector")
class PGVectorDB(VectorDB):
    """A vector database that uses PGVector as the backend."""

    def __init__(
        self,
        *,
        conn: Optional["psycopg.Connection"] = None,
        connection_string: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[Union[int, str]] = None,
        dbname: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        connect_timeout: Optional[int] = 10,
        embedding_function: Callable = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Initialize the vector database.

        Note: connection_string or host + port + dbname must be specified

        Args:
            conn: psycopg.Connection | A customer connection object to connect to the database.
                A connection object may include additional key/values:
                https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
            connection_string: "postgresql://username:password@hostname:port/database" | The PGVector connection string. Default is None.
            host: str | The host to connect to. Default is None.
            port: int | The port to connect to. Default is None.
            dbname: str | The database name to connect to. Default is None.
            username: str | The database username to use. Default is None.
            password: str | The database user password to use. Default is None.
            connect_timeout: int | The timeout to set for the connection. Default is 10.
            embedding_function: Callable | The embedding function used to generate the vector representation.
                Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
                Models can be chosen from:
                https://huggingface.co/models?library=sentence-transformers
            metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
                setting: `{"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 16}`. Creates Index on table
                using hnsw (embedding vector_l2_ops) WITH (m = hnsw:M) ef_construction = "hnsw:construction_ef".
                For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
        Returns:
            None
        """
        self.client = self.establish_connection(
            conn=conn,
            connection_string=connection_string,
            host=host,
            port=port,
            dbname=dbname,
            username=username,
            password=password,
            connect_timeout=connect_timeout,
        )
        if embedding_function:
            self.embedding_function = embedding_function
        else:
            self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode
        self.metadata = metadata
        register_vector(self.client)
        self.active_collection = None

    def establish_connection(
        self,
        conn: Optional["psycopg.Connection"] = None,
        connection_string: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[Union[int, str]] = None,
        dbname: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        connect_timeout: Optional[int] = 10,
    ) -> "psycopg.Connection":
        """Establishes a connection to a PostgreSQL database using psycopg.

        Args:
            conn: An existing psycopg connection object. If provided, this connection will be used.
            connection_string: A string containing the connection information. If provided, a new connection will be established using this string.
            host: The hostname of the PostgreSQL server. Used if connection_string is not provided.
            port: The port number to connect to at the server host. Used if connection_string is not provided.
            dbname: The database name. Used if connection_string is not provided.
            username: The username to connect as. Used if connection_string is not provided.
            password: The user's password. Used if connection_string is not provided.
            connect_timeout: Maximum wait for connection, in seconds. The default is 10 seconds.

        Returns:
            A psycopg.Connection object representing the established connection.

        Raises:
            PermissionError if no credentials are supplied
            psycopg.Error: If an error occurs while trying to connect to the database.
        """
        try:
            if conn:
                self.client = conn
            elif connection_string:
                parsed_connection = urllib.parse.urlparse(connection_string)
                encoded_username = urllib.parse.quote(parsed_connection.username, safe="")
                encoded_password = urllib.parse.quote(parsed_connection.password, safe="")
                encoded_password = f":{encoded_password}@"
                encoded_host = urllib.parse.quote(parsed_connection.hostname, safe="")
                encoded_port = f":{parsed_connection.port}"
                encoded_database = urllib.parse.quote(parsed_connection.path[1:], safe="")
                connection_string_encoded = (
                    f"{parsed_connection.scheme}://{encoded_username}{encoded_password}"
                    f"{encoded_host}{encoded_port}/{encoded_database}"
                )
                self.client = psycopg.connect(conninfo=connection_string_encoded, autocommit=True)
            elif host:
                connection_string = ""
                if host:
                    encoded_host = urllib.parse.quote(host, safe="")
                    connection_string += f"host={encoded_host} "
                if port:
                    connection_string += f"port={port} "
                if dbname:
                    encoded_database = urllib.parse.quote(dbname, safe="")
                    connection_string += f"dbname={encoded_database} "
                if username:
                    encoded_username = urllib.parse.quote(username, safe="")
                    connection_string += f"user={encoded_username} "
                if password:
                    encoded_password = urllib.parse.quote(password, safe="")
                    connection_string += f"password={encoded_password} "

                self.client = psycopg.connect(
                    conninfo=connection_string,
                    connect_timeout=connect_timeout,
                    autocommit=True,
                )
            else:
                logger.error("Credentials were not supplied...")
                raise PermissionError
            self.client.execute("CREATE EXTENSION IF NOT EXISTS vector")
        except psycopg.Error as e:
            logger.error("Error connecting to the database: ", e)
            raise e
        return self.client

    def create_collection(
        self, collection_name: str, overwrite: bool = False, get_or_create: bool = True
    ) -> Collection:
        """Create a collection in the vector database.
        Case 1. if the collection does not exist, create the collection.
        Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
        Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
            otherwise it raise a ValueError.

        Args:
            collection_name: str | The name of the collection.
            overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
            get_or_create: bool | Whether to get the collection if it exists. Default is True.

        Returns:
            Collection | The collection object.
        """
        try:
            if self.active_collection and self.active_collection.name == collection_name:
                collection = self.active_collection
            else:
                collection = self.get_collection(collection_name)
        except ValueError:
            collection = None
        if collection is None:
            collection = Collection(
                client=self.client,
                collection_name=collection_name,
                embedding_function=self.embedding_function,
                get_or_create=get_or_create,
                metadata=self.metadata,
            )
            collection.set_collection_name(collection_name=collection_name)
            collection.create_collection(collection_name=collection_name)
            return collection
        elif overwrite:
            self.delete_collection(collection_name)
            collection = Collection(
                client=self.client,
                collection_name=collection_name,
                embedding_function=self.embedding_function,
                get_or_create=get_or_create,
                metadata=self.metadata,
            )
            collection.set_collection_name(collection_name=collection_name)
            collection.create_collection(collection_name=collection_name)
            return collection
        elif get_or_create:
            return collection
        elif not collection.table_exists(table_name=collection_name):
            collection = Collection(
                client=self.client,
                collection_name=collection_name,
                embedding_function=self.embedding_function,
                get_or_create=get_or_create,
                metadata=self.metadata,
            )
            collection.set_collection_name(collection_name=collection_name)
            collection.create_collection(collection_name=collection_name)
            return collection
        else:
            raise ValueError(f"Collection {collection_name} already exists.")

    def get_collection(self, collection_name: str = None) -> Collection:
        """Get the collection from the vector database.

        Args:
            collection_name: str | The name of the collection. Default is None. If None, return the
                current active collection.

        Returns:
            Collection | The collection object.
        """
        if collection_name is None:
            if self.active_collection is None:
                raise ValueError("No collection is specified.")
            else:
                logger.debug(
                    f"No collection is specified. Using current active collection {self.active_collection.name}."
                )
        else:
            if not (self.active_collection and self.active_collection.name == collection_name):
                self.active_collection = Collection(
                    client=self.client,
                    collection_name=collection_name,
                    embedding_function=self.embedding_function,
                )
        return self.active_collection

    def delete_collection(self, collection_name: str) -> None:
        """Delete the collection from the vector database.

        Args:
            collection_name: str | The name of the collection.

        Returns:
            None
        """
        if self.active_collection:
            self.active_collection.delete_collection(collection_name)
        else:
            collection = self.get_collection(collection_name)
            collection.delete_collection(collection_name)
        if self.active_collection and self.active_collection.name == collection_name:
            self.active_collection = None

    def _batch_insert(
        self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False
    ) -> None:
        batch_size = int(PGVECTOR_MAX_BATCH_SIZE)
        default_metadata = {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
        default_metadatas = [default_metadata] * min(batch_size, len(documents))
        for i in range(0, len(documents), min(batch_size, len(documents))):
            end_idx = i + min(batch_size, len(documents) - i)
            collection_kwargs = {
                "documents": documents[i:end_idx],
                "ids": ids[i:end_idx],
                "metadatas": metadatas[i:end_idx] if metadatas else default_metadatas,
                "embeddings": embeddings[i:end_idx] if embeddings else None,
            }
            if upsert:
                collection.upsert(**collection_kwargs)
            else:
                collection.add(**collection_kwargs)

    def insert_docs(self, docs: list[Document], collection_name: str = None, upsert: bool = False) -> None:
        """Insert documents into the collection of the vector database.

        Args:
            docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
            collection_name: str | The name of the collection. Default is None.
            upsert: bool | Whether to update the document if it exists. Default is False.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            None
        """
        if not docs:
            return
        if docs[0].get("content") is None:
            raise ValueError("The document content is required.")
        if docs[0].get("id") is None:
            raise ValueError("The document id is required.")
        documents = [doc.get("content") for doc in docs]
        ids = [doc.get("id") for doc in docs]

        collection = self.get_collection(collection_name)
        if docs[0].get("embedding") is None:
            logger.debug(
                "No content embedding is provided. "
                "Will use the VectorDB's embedding function to generate the content embedding."
            )
            embeddings = None
        else:
            embeddings = [doc.get("embedding") for doc in docs]
        metadatas = None if docs[0].get("metadata") is None else [doc.get("metadata") for doc in docs]

        self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert)

    def update_docs(self, docs: list[Document], collection_name: str = None) -> None:
        """Update documents in the collection of the vector database.

        Args:
            docs: List[Document] | A list of documents.
            collection_name: str | The name of the collection. Default is None.

        Returns:
            None
        """
        self.insert_docs(docs, collection_name, upsert=True)

    def delete_docs(self, ids: list[ItemID], collection_name: str = None) -> None:
        """Delete documents from the collection of the vector database.

        Args:
            ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
            collection_name: str | The name of the collection. Default is None.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            None
        """
        collection = self.get_collection(collection_name)
        collection.delete(ids=ids, collection_name=collection_name)

    def retrieve_docs(
        self,
        queries: list[str],
        collection_name: str = None,
        n_results: int = 10,
        distance_threshold: float = -1,
    ) -> QueryResults:
        """Retrieve documents from the collection of the vector database based on the queries.

        Args:
            queries: List[str] | A list of queries. Each query is a string.
            collection_name: str | The name of the collection. Default is None.
            n_results: int | The number of relevant documents to return. Default is 10.
            distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
                returned. Don't filter with it if `< 0`. Default is -1.
            kwargs: Dict | Additional keyword arguments.

        Returns:
            QueryResults | The query results. Each query result is a list of list of tuples containing the document and
                the distance.
        """
        collection = self.get_collection(collection_name)
        if isinstance(queries, str):
            queries = [queries]
        results = collection.query(
            query_texts=queries,
            n_results=n_results,
            distance_threshold=distance_threshold,
        )
        logger.debug(f"Retrieve Docs Results:\n{results}")
        return results

    def get_docs_by_ids(
        self, ids: list[ItemID] = None, collection_name: str = None, include=None, **kwargs
    ) -> list[Document]:
        """Retrieve documents from the collection of the vector database based on the ids.

        Args:
            ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
            collection_name: str | The name of the collection. Default is None.
            include: List[str] | The fields to include. Default is None.
                If None, will include ["metadatas", "documents"], ids will always be included.
            kwargs: dict | Additional keyword arguments.

        Returns:
            List[Document] | The results.
        """
        collection = self.get_collection(collection_name)
        include = include if include else ["metadatas", "documents"]
        results = collection.get(ids, include=include, **kwargs)
        logger.debug(f"Retrieve Documents by ID Results:\n{results}")
        return results
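
For orientation only, below is a minimal usage sketch of the `PGVectorDB` class added above. The connection string, database name, collection name, and document ids/contents are illustrative assumptions, not part of the release; running it requires a reachable PostgreSQL server with the pgvector extension and the `retrievechat-pgvector` optional dependencies installed.

# Illustrative sketch only; connection details and document values are assumptions.
from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB

# Connect to a local Postgres instance (pgvector extension is created automatically if missing).
db = PGVectorDB(connection_string="postgresql://postgres:postgres@localhost:5432/vectordb")

# Create (or reuse) a collection; ids must fit the CHAR(8) primary key defined by create_collection.
db.create_collection("demo_docs", overwrite=False, get_or_create=True)
db.insert_docs(
    docs=[
        {"id": "11111111", "content": "PGVector stores embeddings inside PostgreSQL."},
        {"id": "22222222", "content": "AG2 agents can retrieve context from a vector DB."},
    ],
    collection_name="demo_docs",
)

# Retrieve the closest documents for a query; each result is a (Document, distance) tuple.
results = db.retrieve_docs(queries=["How are embeddings stored?"], collection_name="demo_docs", n_results=2)
print(results)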