aiagents4pharma 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/__init__.py +11 -0
- aiagents4pharma/talk2aiagents4pharma/.dockerignore +13 -0
- aiagents4pharma/talk2aiagents4pharma/Dockerfile +133 -0
- aiagents4pharma/talk2aiagents4pharma/README.md +1 -0
- aiagents4pharma/talk2aiagents4pharma/__init__.py +5 -0
- aiagents4pharma/talk2aiagents4pharma/agents/__init__.py +6 -0
- aiagents4pharma/talk2aiagents4pharma/agents/main_agent.py +70 -0
- aiagents4pharma/talk2aiagents4pharma/configs/__init__.py +5 -0
- aiagents4pharma/talk2aiagents4pharma/configs/agents/__init__.py +5 -0
- aiagents4pharma/talk2aiagents4pharma/configs/agents/main_agent/default.yaml +29 -0
- aiagents4pharma/talk2aiagents4pharma/configs/app/__init__.py +0 -0
- aiagents4pharma/talk2aiagents4pharma/configs/app/frontend/__init__.py +0 -0
- aiagents4pharma/talk2aiagents4pharma/configs/app/frontend/default.yaml +102 -0
- aiagents4pharma/talk2aiagents4pharma/configs/config.yaml +4 -0
- aiagents4pharma/talk2aiagents4pharma/docker-compose/cpu/.env.example +23 -0
- aiagents4pharma/talk2aiagents4pharma/docker-compose/cpu/docker-compose.yml +93 -0
- aiagents4pharma/talk2aiagents4pharma/docker-compose/gpu/.env.example +23 -0
- aiagents4pharma/talk2aiagents4pharma/docker-compose/gpu/docker-compose.yml +108 -0
- aiagents4pharma/talk2aiagents4pharma/install.md +154 -0
- aiagents4pharma/talk2aiagents4pharma/states/__init__.py +5 -0
- aiagents4pharma/talk2aiagents4pharma/states/state_talk2aiagents4pharma.py +18 -0
- aiagents4pharma/talk2aiagents4pharma/tests/__init__.py +3 -0
- aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +312 -0
- aiagents4pharma/talk2biomodels/.dockerignore +13 -0
- aiagents4pharma/talk2biomodels/Dockerfile +104 -0
- aiagents4pharma/talk2biomodels/README.md +1 -0
- aiagents4pharma/talk2biomodels/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/agents/__init__.py +6 -0
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py +104 -0
- aiagents4pharma/talk2biomodels/api/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/api/ols.py +75 -0
- aiagents4pharma/talk2biomodels/api/uniprot.py +36 -0
- aiagents4pharma/talk2biomodels/configs/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/configs/agents/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/configs/agents/t2b_agent/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/configs/agents/t2b_agent/default.yaml +14 -0
- aiagents4pharma/talk2biomodels/configs/app/__init__.py +0 -0
- aiagents4pharma/talk2biomodels/configs/app/frontend/__init__.py +0 -0
- aiagents4pharma/talk2biomodels/configs/app/frontend/default.yaml +72 -0
- aiagents4pharma/talk2biomodels/configs/config.yaml +7 -0
- aiagents4pharma/talk2biomodels/configs/tools/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/configs/tools/ask_question/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +30 -0
- aiagents4pharma/talk2biomodels/configs/tools/custom_plotter/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/configs/tools/custom_plotter/default.yaml +8 -0
- aiagents4pharma/talk2biomodels/configs/tools/get_annotation/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/configs/tools/get_annotation/default.yaml +8 -0
- aiagents4pharma/talk2biomodels/install.md +63 -0
- aiagents4pharma/talk2biomodels/models/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/models/basico_model.py +125 -0
- aiagents4pharma/talk2biomodels/models/sys_bio_model.py +60 -0
- aiagents4pharma/talk2biomodels/states/__init__.py +6 -0
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +49 -0
- aiagents4pharma/talk2biomodels/tests/BIOMD0000000449_url.xml +1585 -0
- aiagents4pharma/talk2biomodels/tests/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/tests/article_on_model_537.pdf +0 -0
- aiagents4pharma/talk2biomodels/tests/test_api.py +31 -0
- aiagents4pharma/talk2biomodels/tests/test_ask_question.py +42 -0
- aiagents4pharma/talk2biomodels/tests/test_basico_model.py +67 -0
- aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +190 -0
- aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +92 -0
- aiagents4pharma/talk2biomodels/tests/test_integration.py +116 -0
- aiagents4pharma/talk2biomodels/tests/test_load_biomodel.py +35 -0
- aiagents4pharma/talk2biomodels/tests/test_param_scan.py +71 -0
- aiagents4pharma/talk2biomodels/tests/test_query_article.py +184 -0
- aiagents4pharma/talk2biomodels/tests/test_save_model.py +47 -0
- aiagents4pharma/talk2biomodels/tests/test_search_models.py +35 -0
- aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +44 -0
- aiagents4pharma/talk2biomodels/tests/test_steady_state.py +86 -0
- aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +67 -0
- aiagents4pharma/talk2biomodels/tools/__init__.py +17 -0
- aiagents4pharma/talk2biomodels/tools/ask_question.py +125 -0
- aiagents4pharma/talk2biomodels/tools/custom_plotter.py +165 -0
- aiagents4pharma/talk2biomodels/tools/get_annotation.py +342 -0
- aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +159 -0
- aiagents4pharma/talk2biomodels/tools/load_arguments.py +134 -0
- aiagents4pharma/talk2biomodels/tools/load_biomodel.py +44 -0
- aiagents4pharma/talk2biomodels/tools/parameter_scan.py +310 -0
- aiagents4pharma/talk2biomodels/tools/query_article.py +64 -0
- aiagents4pharma/talk2biomodels/tools/save_model.py +98 -0
- aiagents4pharma/talk2biomodels/tools/search_models.py +96 -0
- aiagents4pharma/talk2biomodels/tools/simulate_model.py +137 -0
- aiagents4pharma/talk2biomodels/tools/steady_state.py +187 -0
- aiagents4pharma/talk2biomodels/tools/utils.py +23 -0
- aiagents4pharma/talk2cells/README.md +1 -0
- aiagents4pharma/talk2cells/__init__.py +5 -0
- aiagents4pharma/talk2cells/agents/__init__.py +6 -0
- aiagents4pharma/talk2cells/agents/scp_agent.py +87 -0
- aiagents4pharma/talk2cells/states/__init__.py +6 -0
- aiagents4pharma/talk2cells/states/state_talk2cells.py +15 -0
- aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +22 -0
- aiagents4pharma/talk2cells/tools/__init__.py +6 -0
- aiagents4pharma/talk2cells/tools/scp_agent/__init__.py +6 -0
- aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py +27 -0
- aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py +78 -0
- aiagents4pharma/talk2knowledgegraphs/.dockerignore +13 -0
- aiagents4pharma/talk2knowledgegraphs/Dockerfile +131 -0
- aiagents4pharma/talk2knowledgegraphs/README.md +1 -0
- aiagents4pharma/talk2knowledgegraphs/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +99 -0
- aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +79 -0
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +13 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/__init__.py +0 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +33 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/database/milvus/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/database/milvus/default.yaml +61 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/ols_terms/default.yaml +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/reactome_pathways/default.yaml +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/uniprot_proteins/default.yaml +6 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/pubchem_utils/default.yaml +5 -0
- aiagents4pharma/talk2knowledgegraphs/datasets/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/datasets/biobridge_primekg.py +607 -0
- aiagents4pharma/talk2knowledgegraphs/datasets/dataset.py +25 -0
- aiagents4pharma/talk2knowledgegraphs/datasets/primekg.py +212 -0
- aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py +210 -0
- aiagents4pharma/talk2knowledgegraphs/docker-compose/cpu/.env.example +23 -0
- aiagents4pharma/talk2knowledgegraphs/docker-compose/cpu/docker-compose.yml +93 -0
- aiagents4pharma/talk2knowledgegraphs/docker-compose/gpu/.env.example +23 -0
- aiagents4pharma/talk2knowledgegraphs/docker-compose/gpu/docker-compose.yml +108 -0
- aiagents4pharma/talk2knowledgegraphs/entrypoint.sh +180 -0
- aiagents4pharma/talk2knowledgegraphs/install.md +165 -0
- aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +886 -0
- aiagents4pharma/talk2knowledgegraphs/states/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +40 -0
- aiagents4pharma/talk2knowledgegraphs/tests/__init__.py +0 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +318 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +248 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +33 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +86 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +125 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +257 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +1444 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_multimodal_subgraph_extraction.py +159 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +152 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +201 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_database_milvus_connection_manager.py +812 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +51 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +49 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_nim_molmim.py +59 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +63 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py +47 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py +40 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +94 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ols.py +70 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_pubchem.py +45 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_reactome.py +44 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_uniprot.py +48 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +759 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +78 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_pubchem_utils.py +123 -0
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +11 -0
- aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +138 -0
- aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
- aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +965 -0
- aiagents4pharma/talk2knowledgegraphs/tools/multimodal_subgraph_extraction.py +374 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +291 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +123 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/database/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py +586 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/embeddings.py +81 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/huggingface.py +111 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/nim_molmim.py +54 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +87 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/sentence_transformer.py +73 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +12 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py +37 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py +129 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ols_terms.py +89 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/pubchem_strings.py +78 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/reactome_pathways.py +71 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/uniprot_proteins.py +98 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +762 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/multimodal_pcst.py +298 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +229 -0
- aiagents4pharma/talk2knowledgegraphs/utils/kg_utils.py +67 -0
- aiagents4pharma/talk2knowledgegraphs/utils/pubchem_utils.py +104 -0
- aiagents4pharma/talk2scholars/.dockerignore +13 -0
- aiagents4pharma/talk2scholars/Dockerfile +104 -0
- aiagents4pharma/talk2scholars/README.md +1 -0
- aiagents4pharma/talk2scholars/__init__.py +7 -0
- aiagents4pharma/talk2scholars/agents/__init__.py +13 -0
- aiagents4pharma/talk2scholars/agents/main_agent.py +89 -0
- aiagents4pharma/talk2scholars/agents/paper_download_agent.py +96 -0
- aiagents4pharma/talk2scholars/agents/pdf_agent.py +101 -0
- aiagents4pharma/talk2scholars/agents/s2_agent.py +135 -0
- aiagents4pharma/talk2scholars/agents/zotero_agent.py +127 -0
- aiagents4pharma/talk2scholars/configs/__init__.py +7 -0
- aiagents4pharma/talk2scholars/configs/agents/__init__.py +7 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/__init__.py +7 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +52 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/default.yaml +19 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +19 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +44 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +19 -0
- aiagents4pharma/talk2scholars/configs/app/__init__.py +7 -0
- aiagents4pharma/talk2scholars/configs/app/frontend/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +72 -0
- aiagents4pharma/talk2scholars/configs/config.yaml +16 -0
- aiagents4pharma/talk2scholars/configs/tools/__init__.py +21 -0
- aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +26 -0
- aiagents4pharma/talk2scholars/configs/tools/paper_download/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/paper_download/default.yaml +124 -0
- aiagents4pharma/talk2scholars/configs/tools/question_and_answer/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/question_and_answer/default.yaml +62 -0
- aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/default.yaml +12 -0
- aiagents4pharma/talk2scholars/configs/tools/search/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +26 -0
- aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +26 -0
- aiagents4pharma/talk2scholars/configs/tools/zotero_read/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +57 -0
- aiagents4pharma/talk2scholars/configs/tools/zotero_write/__inti__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/zotero_write/default.yaml +55 -0
- aiagents4pharma/talk2scholars/docker-compose/cpu/.env.example +21 -0
- aiagents4pharma/talk2scholars/docker-compose/cpu/docker-compose.yml +90 -0
- aiagents4pharma/talk2scholars/docker-compose/gpu/.env.example +21 -0
- aiagents4pharma/talk2scholars/docker-compose/gpu/docker-compose.yml +105 -0
- aiagents4pharma/talk2scholars/install.md +122 -0
- aiagents4pharma/talk2scholars/state/__init__.py +7 -0
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +98 -0
- aiagents4pharma/talk2scholars/tests/__init__.py +3 -0
- aiagents4pharma/talk2scholars/tests/test_agents_main_agent.py +256 -0
- aiagents4pharma/talk2scholars/tests/test_agents_paper_agents_download_agent.py +139 -0
- aiagents4pharma/talk2scholars/tests/test_agents_pdf_agent.py +114 -0
- aiagents4pharma/talk2scholars/tests/test_agents_s2_agent.py +198 -0
- aiagents4pharma/talk2scholars/tests/test_agents_zotero_agent.py +160 -0
- aiagents4pharma/talk2scholars/tests/test_s2_tools_display_dataframe.py +91 -0
- aiagents4pharma/talk2scholars/tests/test_s2_tools_query_dataframe.py +191 -0
- aiagents4pharma/talk2scholars/tests/test_states_state.py +38 -0
- aiagents4pharma/talk2scholars/tests/test_tools_paper_downloader.py +507 -0
- aiagents4pharma/talk2scholars/tests/test_tools_question_and_answer_tool.py +105 -0
- aiagents4pharma/talk2scholars/tests/test_tools_s2_multi.py +307 -0
- aiagents4pharma/talk2scholars/tests/test_tools_s2_retrieve.py +67 -0
- aiagents4pharma/talk2scholars/tests/test_tools_s2_search.py +286 -0
- aiagents4pharma/talk2scholars/tests/test_tools_s2_single.py +298 -0
- aiagents4pharma/talk2scholars/tests/test_utils_arxiv_downloader.py +469 -0
- aiagents4pharma/talk2scholars/tests/test_utils_base_paper_downloader.py +598 -0
- aiagents4pharma/talk2scholars/tests/test_utils_biorxiv_downloader.py +669 -0
- aiagents4pharma/talk2scholars/tests/test_utils_medrxiv_downloader.py +500 -0
- aiagents4pharma/talk2scholars/tests/test_utils_nvidia_nim_reranker.py +117 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_answer_formatter.py +67 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_batch_processor.py +92 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_collection_manager.py +173 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_document_processor.py +68 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_generate_answer.py +72 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_gpu_detection.py +129 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_paper_loader.py +116 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_rag_pipeline.py +88 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_retrieve_chunks.py +190 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_singleton_manager.py +159 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_vector_normalization.py +121 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_vector_store.py +406 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pubmed_downloader.py +1007 -0
- aiagents4pharma/talk2scholars/tests/test_utils_read_helper_utils.py +106 -0
- aiagents4pharma/talk2scholars/tests/test_utils_s2_utils_ext_ids.py +403 -0
- aiagents4pharma/talk2scholars/tests/test_utils_tool_helper_utils.py +85 -0
- aiagents4pharma/talk2scholars/tests/test_utils_zotero_human_in_the_loop.py +266 -0
- aiagents4pharma/talk2scholars/tests/test_utils_zotero_path.py +496 -0
- aiagents4pharma/talk2scholars/tests/test_utils_zotero_pdf_downloader_utils.py +46 -0
- aiagents4pharma/talk2scholars/tests/test_utils_zotero_read.py +743 -0
- aiagents4pharma/talk2scholars/tests/test_utils_zotero_write.py +151 -0
- aiagents4pharma/talk2scholars/tools/__init__.py +9 -0
- aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +12 -0
- aiagents4pharma/talk2scholars/tools/paper_download/paper_downloader.py +442 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/__init__.py +22 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/arxiv_downloader.py +207 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/base_paper_downloader.py +336 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/biorxiv_downloader.py +313 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/medrxiv_downloader.py +196 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/pubmed_downloader.py +323 -0
- aiagents4pharma/talk2scholars/tools/pdf/__init__.py +7 -0
- aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +170 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py +37 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/answer_formatter.py +62 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/batch_processor.py +198 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/collection_manager.py +172 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/document_processor.py +76 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py +97 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/get_vectorstore.py +59 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/gpu_detection.py +150 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py +97 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/paper_loader.py +123 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/rag_pipeline.py +113 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py +197 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/singleton_manager.py +140 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py +86 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/vector_normalization.py +150 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py +327 -0
- aiagents4pharma/talk2scholars/tools/s2/__init__.py +21 -0
- aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py +110 -0
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +111 -0
- aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +233 -0
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +128 -0
- aiagents4pharma/talk2scholars/tools/s2/search.py +101 -0
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +102 -0
- aiagents4pharma/talk2scholars/tools/s2/utils/__init__.py +5 -0
- aiagents4pharma/talk2scholars/tools/s2/utils/multi_helper.py +223 -0
- aiagents4pharma/talk2scholars/tools/s2/utils/search_helper.py +205 -0
- aiagents4pharma/talk2scholars/tools/s2/utils/single_helper.py +216 -0
- aiagents4pharma/talk2scholars/tools/zotero/__init__.py +7 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/__init__.py +7 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/read_helper.py +270 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/review_helper.py +74 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/write_helper.py +194 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_path.py +180 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_pdf_downloader.py +133 -0
- aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +105 -0
- aiagents4pharma/talk2scholars/tools/zotero/zotero_review.py +162 -0
- aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +91 -0
- aiagents4pharma-0.0.0.dist-info/METADATA +335 -0
- aiagents4pharma-0.0.0.dist-info/RECORD +336 -0
- aiagents4pharma-0.0.0.dist-info/WHEEL +4 -0
- aiagents4pharma-0.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,965 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tool for performing multimodal subgraph extraction.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import concurrent.futures
|
|
7
|
+
import logging
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Annotated
|
|
10
|
+
|
|
11
|
+
import hydra
|
|
12
|
+
import pandas as pd
|
|
13
|
+
import pcst_fast
|
|
14
|
+
from langchain_core.messages import ToolMessage
|
|
15
|
+
from langchain_core.tools import BaseTool
|
|
16
|
+
from langchain_core.tools.base import InjectedToolCallId
|
|
17
|
+
from langgraph.prebuilt import InjectedState
|
|
18
|
+
from langgraph.types import Command
|
|
19
|
+
from pydantic import BaseModel, Field
|
|
20
|
+
from pymilvus import Collection
|
|
21
|
+
|
|
22
|
+
from ..utils.database import MilvusConnectionManager
|
|
23
|
+
from ..utils.database.milvus_connection_manager import QueryParams
|
|
24
|
+
from ..utils.extractions.milvus_multimodal_pcst import (
|
|
25
|
+
DynamicLibraryLoader,
|
|
26
|
+
MultimodalPCSTPruning,
|
|
27
|
+
SystemDetector,
|
|
28
|
+
)
|
|
29
|
+
from .load_arguments import ArgumentData
|
|
30
|
+
|
|
31
|
+
# pylint: disable=too-many-lines
|
|
32
|
+
# Initialize the module-level logger. Note that basicConfig is invoked at
# import time with level INFO, before the named logger is created.
|
|
33
|
+
logging.basicConfig(level=logging.INFO)
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class ExtractionParams:
    """
    Container grouping the values needed for subgraph extraction.

    The fields are plain positional dataclass fields; their order is part of
    the constructor signature and must not be changed.
    """

    # Tool state dictionary.
    state: dict
    # Tool configuration.
    cfg: dict
    # Database configuration.
    cfg_db: dict
    # Query data frame (typed as a generic object).
    query_df: object
    # Connection manager (typed as a generic object).
    connection_manager: object
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class MultimodalSubgraphExtractionInput(BaseModel):
    """
    Pydantic input schema used by the multimodal subgraph extraction tool.

    Args:
        tool_call_id: Tool call ID (annotated with InjectedToolCallId).
        state: Injected state (annotated with InjectedState).
        prompt: Prompt to interact with the backend.
        arg_data: Argument for analytical process over graph data.
            Defaults to None.
    """

    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
        description="Tool call ID.",
    )
    state: Annotated[dict, InjectedState] = Field(
        description="Injected state.",
    )
    prompt: str = Field(
        description="Prompt to interact with the backend.",
    )
    # NOTE(review): the default is None although the annotation is not optional.
    arg_data: ArgumentData = Field(
        default=None,
        description="Experiment over graph data.",
    )
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class MultimodalSubgraphExtractionTool(BaseTool):
|
|
67
|
+
"""
|
|
68
|
+
This tool performs subgraph extraction based on user's prompt by taking into account
|
|
69
|
+
the top-k nodes and edges.
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
name: str = "subgraph_extraction"
|
|
73
|
+
description: str = "A tool for subgraph extraction based on user's prompt."
|
|
74
|
+
args_schema: type[BaseModel] = MultimodalSubgraphExtractionInput
|
|
75
|
+
|
|
76
|
+
def __init__(self, **kwargs):
    """
    Initialize the tool and attach the hardware detector and library loader.

    Args:
        **kwargs: Forwarded unchanged to the parent class constructor.
    """
    super().__init__(**kwargs)

    # Attributes are attached through object.__setattr__ rather than normal
    # assignment (presumably because they are not declared model fields --
    # TODO confirm).
    system_detector = SystemDetector()
    object.__setattr__(self, "detector", system_detector)
    library_loader = DynamicLibraryLoader(system_detector)
    object.__setattr__(self, "loader", library_loader)

    mode = "GPU" if library_loader.use_gpu else "CPU"
    logger.info("MultimodalSubgraphExtractionTool initialized with %s mode", mode)
|
|
85
|
+
|
|
86
|
+
def _read_multimodal_files(self, state: Annotated[dict, InjectedState]):
    """
    Load the uploaded multimodal Excel file into a single DataFrame.

    Args:
        state: The injected state for the tool. Its "uploaded_files" entries
            are inspected for "file_type" and "file_path".

    Returns:
        A DataFrame built with the loader's dataframe library. When no
        multimodal file is found it is the empty frame with the columns
        "name" and "node_type"; otherwise it holds the concatenated sheets
        with the columns "q_node_type" (sheet name) and "q_node_name".
    """
    frame_lib = self.loader.df
    sheets = frame_lib.DataFrame({"name": [], "node_type": []})

    # Scan the uploads; if several multimodal files exist, the last one wins.
    logger.info("Looping over uploaded files")
    for uploaded in state["uploaded_files"]:
        if uploaded["file_type"] != "multimodal":
            continue
        # sheet_name=None reads every sheet into a dict keyed by sheet name
        sheets = pd.read_excel(uploaded["file_path"], sheet_name=None)

    logger.info("Checking if multimodal_df is empty")
    if len(sheets) == 0:
        return sheets

    logger.info("Preparing multimodal_df")
    # Stack all sheets; after reset_index the sheet name lands in "level_0"
    # and the per-sheet row index in "level_1".
    merged = frame_lib.DataFrame(pd.concat(sheets).reset_index())
    merged.drop(columns=["level_1"], inplace=True)
    merged.rename(
        columns={"level_0": "q_node_type", "name": "q_node_name"}, inplace=True
    )
    # An Excel sheet name cannot contain a `/`, while a node type such as
    # 'gene/protein' (as in PrimeKG) does; any `-` in the sheet name is
    # normalized to `_` here.
    merged["q_node_type"] = merged["q_node_type"].str.replace("-", "_")
    return merged
|
|
125
|
+
|
|
126
|
+
def _query_milvus_collection(self, node_type, node_type_df, cfg_db):
    """
    Query the Milvus node collection of one node type by node name.

    Args:
        node_type: Node type whose collection is queried; every `/` in it is
            replaced by `_` when the collection name is built.
        node_type_df: DataFrame with a "q_node_name" column holding the node
            names to look up.
        cfg_db: Database configuration; `cfg_db.milvus_db.database_name` is
            the prefix of the collection name.

    Returns:
        A DataFrame (built with the loader's dataframe library) containing
        the queried columns plus a "use_description" column set to False.
        When nothing matches, the frame is empty but still has all columns.
    """
    # Load the collection
    collection = Collection(
        name=f"{cfg_db.milvus_db.database_name}_nodes_{node_type.replace('/', '_')}"
    )
    collection.load()

    # Query the collection with node names from multimodal_df.
    # A series exposing `to_pandas` is converted through it first; any other
    # series is used as is.
    node_names_series = node_type_df["q_node_name"]
    q_node_names = getattr(
        node_names_series, "to_pandas", lambda series=node_names_series: series
    )().tolist()
    q_columns = [
        "node_id",
        "node_name",
        "node_type",
        "feat",
        "feat_emb",
        "desc",
        "desc_emb",
    ]

    # Build the quoted name list separately (as the async variant does):
    # reusing the enclosing quote character inside a nested f-string is only
    # valid syntax from Python 3.12 on.
    # NOTE(review): names are interpolated unescaped; a name containing a
    # double quote would alter the filter expression.
    node_names_str = ",".join(f'"{name}"' for name in q_node_names)
    res = collection.query(
        expr=f"node_name IN [{node_names_str}]",
        output_fields=q_columns,
    )

    # Convert the embeddings into floats
    for r_ in res:
        r_["feat_emb"] = [float(x) for x in r_["feat_emb"]]
        r_["desc_emb"] = [float(x) for x in r_["desc_emb"]]

    # Convert the result to a DataFrame. An empty result has no columns to
    # select from, so build an empty frame with the expected columns instead
    # of raising KeyError (same handling as the async variant).
    res_df = (
        self.loader.df.DataFrame(res)[q_columns]
        if res
        else self.loader.df.DataFrame(columns=q_columns)
    )
    res_df["use_description"] = False
    return res_df
|
|
161
|
+
|
|
162
|
+
async def _query_milvus_collection_async(
    self, node_type, node_type_df, cfg_db, connection_manager
):
    """
    Asynchronously query the node-type-specific Milvus collection.

    Args:
        node_type: Node type selecting the collection; any ``/`` in it is
            replaced by ``_`` when the collection name is built.
        node_type_df: DataFrame with a ``q_node_name`` column holding the
            node names to look up.
        cfg_db: Milvus configuration; ``cfg_db.milvus_db.database_name`` is
            used as the collection-name prefix.
        connection_manager: Object whose ``async_query`` coroutine executes
            the prepared ``QueryParams``.

    Returns:
        A DataFrame with the queried node columns plus a ``use_description``
        column set to ``False``. It has no rows when nothing matched.
    """
    collection_name = f"{cfg_db.milvus_db.database_name}_nodes_{node_type.replace('/', '_')}"

    # Series that expose to_pandas() (presumably GPU-backed frames) are
    # converted first; any other series is used as-is.
    node_names_series = node_type_df["q_node_name"]
    q_node_names = getattr(
        node_names_series, "to_pandas", lambda series=node_names_series: series
    )().tolist()

    # Create the filter expression. Backslashes and double quotes are escaped
    # so a name cannot terminate its own string literal.
    node_names_str = ",".join(
        '"' + str(name).replace("\\", "\\\\").replace('"', '\\"') + '"'
        for name in q_node_names
    )
    expr = f"node_name IN [{node_names_str}]"

    q_columns = [
        "node_id",
        "node_name",
        "node_type",
        "feat",
        "feat_emb",
        "desc",
        "desc_emb",
    ]

    # Create query parameters and perform the async query
    query_params = QueryParams(
        collection_name=collection_name, expr=expr, output_fields=q_columns
    )
    res = await connection_manager.async_query(query_params)

    # Convert the embeddings into floats
    for r_ in res:
        r_["feat_emb"] = [float(x) for x in r_["feat_emb"]]
        r_["desc_emb"] = [float(x) for x in r_["desc_emb"]]

    # An empty result carries no keys, so fall back to an empty frame that
    # still has the expected columns.
    if res:
        res_df = self.loader.df.DataFrame(res)[q_columns]
    else:
        res_df = self.loader.df.DataFrame(columns=q_columns)
    res_df["use_description"] = False
    return res_df
|
|
207
|
+
|
|
208
|
+
def _prepare_query_modalities(
    self, prompt: dict, state: Annotated[dict, InjectedState], cfg_db: dict
):
    """
    Prepare the modality-specific query for subgraph extraction.

    When multimodal files yield node names, ``state["selections"]`` is
    overwritten with a mapping of node type to the matched node IDs.

    Args:
        prompt: The dictionary containing the user prompt (``"text"``) and
            its embedding (``"emb"``).
        state: The injected state for the tool.
        cfg_db: The configuration dictionary for Milvus database.

    Returns:
        A DataFrame containing the query embeddings and modalities; the user
        prompt is always its last (or only) row.
    """
    # Initialize dataframes
    logger.log(logging.INFO, "Initializing dataframes")
    prompt_df = self.loader.df.DataFrame(
        {
            "node_id": "user_prompt",
            "node_name": "User Prompt",
            "node_type": "prompt",
            "feat": prompt["text"],
            "feat_emb": prompt["emb"],
            "desc": prompt["text"],
            "desc_emb": prompt["emb"],
            "use_description": True,  # set to True for user prompt embedding
        }
    )

    # Read multimodal files uploaded by the user
    multimodal_df = self._read_multimodal_files(state)

    logger.log(logging.INFO, "Prepare query modalities")
    if len(multimodal_df) > 0:
        # Query the Milvus database for each node type in multimodal_df
        logger.log(
            logging.INFO,
            "Querying Milvus database for each node type in multimodal_df",
        )
        query_results = []
        for node_type, node_type_df in multimodal_df.groupby("q_node_type"):
            # Report progress through the logger rather than stdout
            logger.log(logging.INFO, "Processing node type: %s", node_type)
            query_results.append(
                self._query_milvus_collection(node_type, node_type_df, cfg_db)
            )

        # Concatenate all results into a single DataFrame
        logger.log(logging.INFO, "Concatenating all results into a single DataFrame")
        query_df = self.loader.df.concat(query_results, ignore_index=True)

        # Update the state by adding the selected node IDs
        logger.log(logging.INFO, "Updating state with selected node IDs")
        state["selections"] = (
            getattr(query_df, "to_pandas", lambda: query_df)()
            .groupby("node_type")["node_id"]
            .apply(list)
            .to_dict()
        )

        # Append a user prompt to the query dataframe
        logger.log(logging.INFO, "Adding user prompt to query dataframe")
        query_df = self.loader.df.concat([query_df, prompt_df]).reset_index(drop=True)
    else:
        # If no multimodal files are uploaded, use the prompt embeddings
        query_df = prompt_df

    return query_df
|
|
275
|
+
|
|
276
|
+
async def _prepare_query_modalities_async(
    self,
    prompt: dict,
    state: Annotated[dict, InjectedState],
    cfg_db: dict,
    connection_manager,
):
    """
    Prepare the modality-specific query for subgraph extraction asynchronously.

    When multimodal files yield node names, ``state["selections"]`` is
    overwritten with a mapping of node type to the matched node IDs.

    Args:
        prompt: The dictionary containing the user prompt (``"text"``) and
            its embedding (``"emb"``)
        state: The injected state for the tool
        cfg_db: The configuration dictionary for Milvus database
        connection_manager: The MilvusConnectionManager instance

    Returns:
        A DataFrame containing the query embeddings and modalities; the user
        prompt is always its last (or only) row
    """
    # Initialize dataframes
    logger.log(logging.INFO, "Initializing dataframes (async)")
    prompt_df = self.loader.df.DataFrame(
        {
            "node_id": "user_prompt",
            "node_name": "User Prompt",
            "node_type": "prompt",
            "feat": prompt["text"],
            "feat_emb": prompt["emb"],
            "desc": prompt["text"],
            "desc_emb": prompt["emb"],
            "use_description": True,  # set to True for user prompt embedding
        }
    )

    # Read multimodal files uploaded by the user
    multimodal_df = self._read_multimodal_files(state)

    logger.log(logging.INFO, "Prepare query modalities (async)")
    if len(multimodal_df) > 0:
        logger.log(
            logging.INFO,
            "Querying Milvus database for each node type in multimodal_df (parallel)",
        )

        # The queries are awaited one at a time (the original comments cite
        # event loop issues as the reason). Each coroutine is created right
        # before it is awaited so that a failing query does not leave
        # already-created coroutines un-awaited.
        query_results = []
        for node_type, node_type_df in multimodal_df.groupby("q_node_type"):
            logger.log(logging.INFO, "Processing node type: %s", node_type)
            query_results.append(
                await self._query_milvus_collection_async(
                    node_type, node_type_df, cfg_db, connection_manager
                )
            )

        # Concatenate all results into a single DataFrame
        logger.log(logging.INFO, "Concatenating all results into a single DataFrame")
        query_df = self.loader.df.concat(query_results, ignore_index=True)

        # Update the state by adding the selected node IDs
        logger.log(logging.INFO, "Updating state with selected node IDs")
        state["selections"] = (
            getattr(query_df, "to_pandas", lambda: query_df)()
            .groupby("node_type")["node_id"]
            .apply(list)
            .to_dict()
        )

        # Append a user prompt to the query dataframe
        logger.log(logging.INFO, "Adding user prompt to query dataframe")
        query_df = self.loader.df.concat([query_df, prompt_df]).reset_index(drop=True)
    else:
        # If no multimodal files are uploaded, use the prompt embeddings
        query_df = prompt_df

    return query_df
|
|
366
|
+
|
|
367
|
+
def _perform_subgraph_extraction(
    self,
    state: Annotated[dict, InjectedState],
    cfg: dict,
    cfg_db: dict,
    query_df: pd.DataFrame,
) -> dict:
    """
    Perform multimodal subgraph extraction based on modal-specific embeddings.

    Args:
        state: The injected state for the tool; ``topk_nodes`` and
            ``topk_edges`` are read from it.
        cfg: The configuration holding the PCST parameters.
        cfg_db: The configuration dictionary for Milvus database.
        query_df: The DataFrame containing the query embeddings and modalities.

    Returns:
        A DataFrame with columns ``name``, ``nodes`` and ``edges``: the first
        row is the "Unified Subgraph" (unique node and edge indices over all
        queries), followed by one row per query.
    """
    subgraphs = []
    unified_subgraph = {"nodes": [], "edges": []}

    # The metric type depends only on the configuration and the loader, so it
    # is resolved once instead of on every iteration.
    dynamic_metric_type = self._get_dynamic_metric_type(cfg)

    # Loop over query embeddings and modalities
    for _, row in getattr(query_df, "to_pandas", lambda: query_df)().iterrows():
        logger.log(logging.INFO, "===========================================")
        logger.log(logging.INFO, "Processing query: %s", row["node_name"])

        # Prepare the PCSTPruning object and extract the subgraph
        subgraph = MultimodalPCSTPruning(
            topk=state["topk_nodes"],
            topk_e=state["topk_edges"],
            cost_e=cfg.cost_e,
            c_const=cfg.c_const,
            root=cfg.root,
            num_clusters=cfg.num_clusters,
            pruning=cfg.pruning,
            verbosity_level=cfg.verbosity_level,
            use_description=row["use_description"],
            metric_type=dynamic_metric_type,  # Use dynamic or config metric type
            loader=self.loader,  # Pass the loader instance
        ).extract_subgraph(row["desc_emb"], row["feat_emb"], row["node_type"], cfg_db)

        # Convert each array once and record it for both outputs
        nodes = subgraph["nodes"].tolist()
        edges = subgraph["edges"].tolist()
        unified_subgraph["nodes"].append(nodes)
        unified_subgraph["edges"].append(edges)
        subgraphs.append((row["node_name"], nodes, edges))

    # De-duplicate the unified indices and assemble the result frame using
    # the same helper as the async path.
    return self._finalize_subgraph_results(subgraphs, unified_subgraph)
|
|
468
|
+
|
|
469
|
+
async def _perform_subgraph_extraction_async(self, params: ExtractionParams) -> dict:
    """
    Perform multimodal subgraph extraction based on modal-specific embeddings asynchronously.

    Args:
        params: Bundle of extraction inputs. This method reads
            ``params.query_df`` (query embeddings and modalities),
            ``params.cfg`` (tool configuration), ``params.cfg_db`` (Milvus
            configuration) and ``params.connection_manager``.

    Returns:
        A DataFrame with columns ``name``, ``nodes`` and ``edges``: the first
        row is the "Unified Subgraph", followed by one row per query.
    """
    subgraphs = []
    unified_subgraph = {"nodes": [], "edges": []}

    query_rows = [
        row
        for _, row in getattr(
            params.query_df, "to_pandas", lambda: params.query_df
        )().iterrows()
    ]

    # The metric type does not depend on the query row; resolve it once.
    dynamic_metric_type = self._get_dynamic_metric_type(params.cfg)

    # Extractions run one after another (the original comments cite event
    # loop conflicts as the reason). Each coroutine is created right before
    # it is awaited so a failure cannot leave later coroutines un-awaited.
    subgraph_results = []
    for i, row in enumerate(query_rows):
        logger.log(logging.INFO, "===========================================")
        logger.log(logging.INFO, "Processing query: %s", row["node_name"])
        logger.log(logging.INFO, "Processing subgraph %d/%d", i + 1, len(query_rows))

        pcst_instance = self._create_pcst_instance(params, row, dynamic_metric_type)
        subgraph_results.append(
            await self._extract_single_subgraph_async(
                pcst_instance, row, params.cfg_db, params.connection_manager
            )
        )

    # Process results and finalize
    self._process_subgraph_results(subgraph_results, query_rows, unified_subgraph, subgraphs)
    return self._finalize_subgraph_results(subgraphs, unified_subgraph)
|
|
520
|
+
|
|
521
|
+
def _process_subgraph_results(self, subgraph_results, query_info, unified_subgraph, subgraphs):
|
|
522
|
+
"""Process individual subgraph results."""
|
|
523
|
+
for i, subgraph in enumerate(subgraph_results):
|
|
524
|
+
query_row = query_info[i]
|
|
525
|
+
unified_subgraph["nodes"].append(subgraph["nodes"].tolist())
|
|
526
|
+
unified_subgraph["edges"].append(subgraph["edges"].tolist())
|
|
527
|
+
subgraphs.append(
|
|
528
|
+
(
|
|
529
|
+
query_row["node_name"],
|
|
530
|
+
subgraph["nodes"].tolist(),
|
|
531
|
+
subgraph["edges"].tolist(),
|
|
532
|
+
)
|
|
533
|
+
)
|
|
534
|
+
|
|
535
|
+
def _finalize_subgraph_results(self, subgraphs, unified_subgraph):
|
|
536
|
+
"""Process and finalize subgraph results into DataFrames."""
|
|
537
|
+
# Concatenate and get unique node and edge indices
|
|
538
|
+
nodes_arrays = [self.loader.py.array(list_) for list_ in unified_subgraph["nodes"]]
|
|
539
|
+
unified_subgraph["nodes"] = self.loader.py.unique(
|
|
540
|
+
self.loader.py.concatenate(nodes_arrays)
|
|
541
|
+
).tolist()
|
|
542
|
+
edges_arrays = [self.loader.py.array(list_) for list_ in unified_subgraph["edges"]]
|
|
543
|
+
unified_subgraph["edges"] = self.loader.py.unique(
|
|
544
|
+
self.loader.py.concatenate(edges_arrays)
|
|
545
|
+
).tolist()
|
|
546
|
+
|
|
547
|
+
# Convert the unified subgraph and subgraphs to DataFrames
|
|
548
|
+
unified_subgraph_df = self.loader.df.DataFrame(
|
|
549
|
+
[
|
|
550
|
+
(
|
|
551
|
+
"Unified Subgraph",
|
|
552
|
+
unified_subgraph["nodes"],
|
|
553
|
+
unified_subgraph["edges"],
|
|
554
|
+
)
|
|
555
|
+
],
|
|
556
|
+
columns=["name", "nodes", "edges"],
|
|
557
|
+
)
|
|
558
|
+
subgraphs_df = self.loader.df.DataFrame(subgraphs, columns=["name", "nodes", "edges"])
|
|
559
|
+
|
|
560
|
+
# Concatenate both DataFrames
|
|
561
|
+
return self.loader.df.concat([unified_subgraph_df, subgraphs_df], ignore_index=True)
|
|
562
|
+
|
|
563
|
+
async def _extract_single_subgraph_async(
|
|
564
|
+
self, pcst_instance, query_row, cfg_db, connection_manager
|
|
565
|
+
):
|
|
566
|
+
"""
|
|
567
|
+
Extract a single subgraph asynchronously using the new async methods.
|
|
568
|
+
"""
|
|
569
|
+
# Load data and compute prizes
|
|
570
|
+
edge_index, prizes, num_nodes = await self._load_subgraph_data(
|
|
571
|
+
pcst_instance, query_row, cfg_db, connection_manager
|
|
572
|
+
)
|
|
573
|
+
|
|
574
|
+
# Run PCST algorithm and get results
|
|
575
|
+
return self._run_pcst_algorithm(pcst_instance, edge_index, num_nodes, prizes)
|
|
576
|
+
|
|
577
|
+
async def _load_subgraph_data(self, pcst_instance, query_row, cfg_db, connection_manager):
|
|
578
|
+
"""Load edge index, compute prizes, and get node count."""
|
|
579
|
+
# Load edge index asynchronously
|
|
580
|
+
edge_index = await pcst_instance.load_edge_index_async(cfg_db, connection_manager)
|
|
581
|
+
|
|
582
|
+
# Compute prizes asynchronously
|
|
583
|
+
prizes = await pcst_instance.compute_prizes_async(
|
|
584
|
+
query_row["desc_emb"],
|
|
585
|
+
query_row["feat_emb"],
|
|
586
|
+
cfg_db,
|
|
587
|
+
query_row["node_type"],
|
|
588
|
+
)
|
|
589
|
+
|
|
590
|
+
# Get number of nodes
|
|
591
|
+
nodes_collection = f"{cfg_db.milvus_db.database_name}_nodes"
|
|
592
|
+
stats = await connection_manager.async_get_collection_stats(nodes_collection)
|
|
593
|
+
num_nodes = stats["num_entities"]
|
|
594
|
+
|
|
595
|
+
return edge_index, prizes, num_nodes
|
|
596
|
+
|
|
597
|
+
def _run_pcst_algorithm(self, pcst_instance, edge_index, num_nodes, prizes):
    """
    Run the PCST solver for one query and return its subgraph.

    The costs are computed by ``pcst_instance.compute_subgraph_costs``, the
    solver ``pcst_fast.pcst_fast`` is called with the instance's ``root``,
    ``num_clusters``, ``pruning`` and ``verbosity_level``, and the selected
    vertices/edges are translated by
    ``pcst_instance.get_subgraph_nodes_edges``, whose result is returned.
    """
    # Compute costs in constructing the subgraph
    cost_inputs = pcst_instance.compute_subgraph_costs(edge_index, num_nodes, prizes)
    edge_info, final_prizes, edge_costs, node_mapping = cost_inputs

    # Retrieve the subgraph using the PCST algorithm
    solver_args = (
        edge_info["edges"].tolist(),
        final_prizes.tolist(),
        edge_costs.tolist(),
        pcst_instance.root,
        pcst_instance.num_clusters,
        pcst_instance.pruning,
        pcst_instance.verbosity_level,
    )
    picked_vertices, picked_edges = pcst_fast.pcst_fast(*solver_args)

    # Convert the solver output with the instance's array library
    vertex_array = pcst_instance.loader.py.asarray(picked_vertices)
    edge_bundle = {
        "edges": pcst_instance.loader.py.asarray(picked_edges),
        "num_prior_edges": edge_info["num_prior_edges"],
        "edge_index": edge_index,
    }

    # Get subgraph nodes and edges based on the PCST result
    return pcst_instance.get_subgraph_nodes_edges(
        num_nodes, vertex_array, edge_bundle, node_mapping
    )
|
|
626
|
+
|
|
627
|
+
def _run(
    self,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[dict, InjectedState],
    prompt: str,
    arg_data: ArgumentData = None,
) -> Command:
    """
    Synchronous wrapper for the async _run_async method.
    This maintains compatibility with LangGraph while using async operations internally.

    Args:
        tool_call_id: The tool call ID for the tool.
        state: Injected state for the tool.
        prompt: The prompt to interact with the backend.
        arg_data (ArgumentData): The argument data.

    Returns:
        Command: The command produced by ``_run_async``.
    """

    def run_in_thread():
        """Run the async implementation on a fresh event loop in this thread."""
        # asyncio.run creates a new loop, and on exit cancels leftover tasks,
        # shuts down async generators and closes the loop. The coroutine is
        # created here, inside the worker, so it is always awaited.
        return asyncio.run(self._run_async(tool_call_id, state, prompt, arg_data))

    # Always use a separate thread to avoid conflicts with an event loop that
    # may already be running in the calling thread
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(run_in_thread).result()
|
|
659
|
+
|
|
660
|
+
def _prepare_final_subgraph(
    self, state: Annotated[dict, InjectedState], subgraph: dict, cfg_db
) -> dict:
    """
    Prepare the visualization and text form of the extracted subgraphs.

    Args:
        state: The injected state for the tool; ``state["selections"]`` maps
            a node type to the node IDs that get that type's color.
        subgraph: The extracted subgraphs, one row per subgraph with the
            fields ``name``, ``nodes`` and ``edges``.
        cfg_db: The configuration for Milvus database; ``node_colors_dict``
            is read from it.

    Returns:
        A dictionary with the keys ``name`` (subgraph names), ``nodes`` and
        ``edges`` (per-subgraph lists of node/edge tuples with display
        attributes) and ``text`` (CSV of nodes and edges of the
        "Unified Subgraph", or ``""`` if there is none).
    """

    def as_pandas(frame):
        """Return ``frame.to_pandas()`` when available, else ``frame`` itself."""
        return getattr(frame, "to_pandas", lambda: frame)()

    # Map every selected node ID to the color of its node type
    node_colors = {
        n: cfg_db.node_colors_dict[k] for k, v in state["selections"].items() for n in v
    }
    color_df = self.loader.df.DataFrame(list(node_colors.items()), columns=["node_id", "color"])

    # Prepare the subgraph dictionary
    graph_dict = {"name": [], "nodes": [], "edges": [], "text": ""}
    for sub in as_pandas(subgraph).itertuples(index=False):
        graph_nodes, graph_edges = self._process_subgraph_data(sub, cfg_db, color_df)

        # Prepare lists for visualization
        graph_dict["name"].append(sub.name)
        graph_dict["nodes"].append(
            [
                (
                    row.node_id,
                    {
                        # f-string formatting also tolerates non-string fields
                        "hover": (
                            f"Node Name : {row.node_name}\n"
                            f"Node Type : {row.node_type}\n"
                            f"Desc : {row.desc}"
                        ),
                        "click": "$hover",
                        "color": row.color,
                    },
                )
                for row in as_pandas(graph_nodes).itertuples(index=False)
            ]
        )
        graph_dict["edges"].append(
            [
                (row.head_id, row.tail_id, {"label": tuple(row.edge_type)})
                for row in as_pandas(graph_edges).itertuples(index=False)
            ]
        )

        # Prepare the textualized subgraph
        if sub.name == "Unified Subgraph":
            # rename() returns a new frame; renaming a column slice in place
            # would modify a temporary and can trigger SettingWithCopyWarning
            text_nodes = graph_nodes[["node_id", "desc"]].rename(columns={"desc": "node_attr"})
            text_edges = graph_edges[["head_id", "edge_type", "tail_id"]]
            nodes_csv = as_pandas(text_nodes).to_csv(index=False)
            edges_csv = as_pandas(text_edges).to_csv(index=False)
            graph_dict["text"] = nodes_csv + "\n" + edges_csv

    return graph_dict
|
|
743
|
+
|
|
744
|
+
def _process_subgraph_data(self, sub, cfg_db, color_df):
    """
    Fetch the node and edge records of one subgraph from Milvus.

    Args:
        sub: Subgraph record with the attributes ``name``, ``nodes`` (node
            indices) and ``edges`` (triplet indices).
        cfg_db: Milvus configuration; ``cfg_db.milvus_db.database_name`` is
            the prefix of the ``_nodes`` and ``_edges`` collections.
        color_df: DataFrame with ``node_id`` and ``color`` columns; nodes
            without an entry are colored ``"black"``.

    Returns:
        A tuple ``(graph_nodes, graph_edges)`` of DataFrames. ``graph_nodes``
        has the node fields plus ``color``; ``graph_edges`` has ``head_id``,
        ``tail_id`` and ``edge_type`` (split on ``|`` into a list).
    """
    # Log a summary instead of printing the full index lists to stdout
    logger.debug(
        "Processing subgraph: %s (%d nodes, %d edges)",
        sub.name,
        len(sub.nodes),
        len(sub.edges),
    )

    # Prepare graph dataframes - Nodes
    node_fields = ["node_id", "node_name", "node_type", "desc"]
    node_coll = Collection(name=f"{cfg_db.milvus_db.database_name}_nodes")
    node_coll.load()
    node_rows = node_coll.query(
        expr=f"node_index IN [{','.join(str(n) for n in sub.nodes)}]",
        output_fields=node_fields,
    )
    if node_rows:
        # The index column comes back with the result rows (the original
        # code drops it unconditionally); it is not needed downstream
        graph_nodes = self.loader.df.DataFrame(node_rows).drop(
            columns=["node_index"], errors="ignore"
        )
    else:
        # An empty result has no columns; keep the expected schema
        graph_nodes = self.loader.df.DataFrame(columns=node_fields)
    if not color_df.empty:
        graph_nodes = graph_nodes.merge(color_df, on="node_id", how="left")
    else:
        graph_nodes["color"] = "black"
    graph_nodes["color"] = graph_nodes["color"].fillna("black")

    # Edges
    edge_fields = ["head_id", "tail_id", "edge_type"]
    edge_coll = Collection(name=f"{cfg_db.milvus_db.database_name}_edges")
    edge_coll.load()
    edge_rows = edge_coll.query(
        expr=f"triplet_index IN [{','.join(str(e) for e in sub.edges)}]",
        output_fields=edge_fields,
    )
    if edge_rows:
        graph_edges = self.loader.df.DataFrame(edge_rows).drop(
            columns=["triplet_index"], errors="ignore"
        )
        graph_edges["edge_type"] = graph_edges["edge_type"].str.split("|")
    else:
        # A subgraph without edges yields no rows; keep the expected schema
        graph_edges = self.loader.df.DataFrame(columns=edge_fields)

    return graph_nodes, graph_edges
|
|
782
|
+
|
|
783
|
+
def _get_dynamic_metric_type(self, cfg: dict) -> str:
|
|
784
|
+
"""Helper method to get dynamic metric type."""
|
|
785
|
+
has_vector_processing = hasattr(cfg, "vector_processing")
|
|
786
|
+
if has_vector_processing:
|
|
787
|
+
dynamic_metrics_enabled = getattr(cfg.vector_processing, "dynamic_metrics", True)
|
|
788
|
+
else:
|
|
789
|
+
dynamic_metrics_enabled = False
|
|
790
|
+
if has_vector_processing and dynamic_metrics_enabled:
|
|
791
|
+
return self.loader.metric_type
|
|
792
|
+
return getattr(cfg, "search_metric_type", self.loader.metric_type)
|
|
793
|
+
|
|
794
|
+
def _create_pcst_instance(
    self, params: ExtractionParams, query_row: dict, dynamic_metric_type: str
) -> MultimodalPCSTPruning:
    """
    Build the ``MultimodalPCSTPruning`` object for one query row.

    The top-k limits come from ``params.state``, the PCST parameters from
    ``params.cfg``, ``use_description`` from the query row, and the metric
    type and loader are passed through unchanged.
    """
    state = params.state
    top_k = {"topk": state["topk_nodes"], "topk_e": state["topk_edges"]}

    cfg = params.cfg
    return MultimodalPCSTPruning(
        **top_k,
        cost_e=cfg.cost_e,
        c_const=cfg.c_const,
        root=cfg.root,
        num_clusters=cfg.num_clusters,
        pruning=cfg.pruning,
        verbosity_level=cfg.verbosity_level,
        use_description=query_row["use_description"],
        metric_type=dynamic_metric_type,
        loader=self.loader,
    )
|
|
811
|
+
|
|
812
|
+
def normalize_vector(self, v: list) -> list:
    """
    Normalize a vector using appropriate library (CuPy for GPU, NumPy for CPU).

    Normalization is applied only when ``self.loader.normalize_vectors`` is
    truthy; otherwise the input is returned unchanged.

    Args:
        v : Vector to normalize.

    Returns:
        The vector scaled to unit length as a list. A vector whose norm is
        zero is returned unscaled (as a list) instead of being divided by
        zero, which would produce NaN entries.
    """
    if not self.loader.normalize_vectors:
        # Normalization disabled: return as-is (the original comment notes
        # this is the CPU mode, used with COSINE similarity)
        return v

    v_array = self.loader.py.asarray(v)
    norm = self.loader.py.linalg.norm(v_array)
    if float(norm) == 0.0:
        # A zero vector has no direction; keep it unchanged
        return v_array.tolist()
    return (v_array / norm).tolist()
|
|
829
|
+
|
|
830
|
+
async def _run_async(
    self,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[dict, InjectedState],
    prompt: str,
    arg_data: ArgumentData = None,
) -> Command:
    """
    Execute the subgraph extraction tool end to end.

    Steps: load the Hydra configuration, make sure the Milvus connection is
    up, build the query frame from the prompt (and any multimodal files),
    extract the subgraphs, prepare them for display and wrap everything in
    the resulting command.

    Args:
        tool_call_id: The tool call ID for the tool.
        state: Injected state for the tool.
        prompt: The prompt to interact with the backend.
        arg_data (ArgumentData): The argument data.

    Returns:
        Command: The command to be executed.

    Raises:
        RuntimeError: If the Milvus connection cannot be established.
    """
    logger.log(logging.INFO, "Invoking subgraph_extraction tool")

    # Tool settings come from a composition with the tool override applied
    with hydra.initialize(version_base=None, config_path="../configs"):
        composed = hydra.compose(
            config_name="config",
            overrides=["tools/multimodal_subgraph_extraction=default"],
        )
        cfg = composed.tools.multimodal_subgraph_extraction

    # Database settings are read from a second, plain composition
    with hydra.initialize(version_base=None, config_path="../configs"):
        cfg_all = hydra.compose(config_name="config")
        cfg_db = cfg_all.utils.database.milvus

    # Establish Milvus connection using singleton connection manager
    logger.log(logging.INFO, "Getting Milvus connection manager (singleton)")
    connection_manager = MilvusConnectionManager(cfg_db)
    try:
        connection_manager.ensure_connection()
        logger.log(logging.INFO, "Milvus connection established successfully")

        # Log connection info
        conn_info = connection_manager.get_connection_info()
        logger.log(logging.INFO, "Connected to database: %s", conn_info.get("database"))
        logger.log(
            logging.INFO,
            "Connection healthy: %s",
            connection_manager.test_connection(),
        )
    except Exception as e:
        logger.error("Failed to establish Milvus connection: %s", str(e))
        raise RuntimeError(f"Cannot connect to Milvus database: {str(e)}") from e

    # Prepare the query embeddings and modalities (async)
    logger.log(logging.INFO, "_prepare_query_modalities_async")
    prompt_embedding = self.normalize_vector(state["embedding_model"].embed_query(prompt))
    prompt_payload = {"text": prompt, "emb": [prompt_embedding]}
    query_df = await self._prepare_query_modalities_async(
        prompt_payload, state, cfg_db, connection_manager
    )

    # Perform subgraph extraction (async)
    logger.log(logging.INFO, "_perform_subgraph_extraction_async")
    subgraphs = await self._perform_subgraph_extraction_async(
        ExtractionParams(
            state=state,
            cfg=cfg,
            cfg_db=cfg_db,
            query_df=query_df,
            connection_manager=connection_manager,
        )
    )

    # Prepare the display and text form of the extracted subgraphs
    logger.log(logging.INFO, "_prepare_final_subgraph")
    logger.log(logging.INFO, "Subgraphs extracted: %s", len(subgraphs))
    final_subgraph = self._prepare_final_subgraph(state, subgraphs, cfg_db)

    # Create final result and return command
    return self._create_extraction_result(tool_call_id, state, final_subgraph, arg_data)
|
|
917
|
+
|
|
918
|
+
def _create_extraction_result(self, tool_call_id, state, final_subgraph, arg_data):
    """Package the extracted subgraph into a state-updating ``Command``.

    Args:
        tool_call_id: Identifier of the tool call; stored in the record and
            attached to the returned ``ToolMessage``.
        state: Mapping read for ``dic_source_graph`` (the first entry's
            ``"name"``), ``topk_nodes`` and ``topk_edges``.
        final_subgraph: Mapping providing ``"name"``, ``"nodes"``,
            ``"edges"`` and ``"text"``.
            NOTE(review): ``"nodes"`` and ``"edges"`` are iterated and each
            item is passed to ``len`` -- presumably one collection per
            graph; confirm against ``_prepare_final_subgraph``.
        arg_data: Object exposing ``extraction_name``.

    Returns:
        A ``Command`` whose ``update`` holds the new record as a one-element
        list under ``"dic_extracted_graph"`` and a single ``ToolMessage``
        under ``"messages"``.
    """
    logger.log(logging.INFO, "dic_extracted_graph")

    # Record describing this extraction; the summary is left unset here.
    extracted = {
        "name": arg_data.extraction_name,
        "tool_call_id": tool_call_id,
        "graph_source": state["dic_source_graph"][0]["name"],
        "topk_nodes": state["topk_nodes"],
        "topk_edges": state["topk_edges"],
        "graph_dict": {
            "name": final_subgraph["name"],
            "nodes": final_subgraph["nodes"],
            "edges": final_subgraph["edges"],
        },
        "graph_text": final_subgraph["text"],
        "graph_summary": None,
    }

    # Debug logging: top-level keys, then the size of each graph_dict entry.
    logger.info("Created dic_extracted_graph with keys: %s", list(extracted))
    graph_dict = extracted["graph_dict"]
    logger.info(
        "Graph dict structure - name count: %d, nodes count: %d, edges count: %d",
        *(len(graph_dict[part]) for part in ("name", "nodes", "edges")),
    )

    # Totals quoted in the user-facing confirmation message.
    graph_count = len(final_subgraph["name"])
    node_total = sum(len(group) for group in final_subgraph["nodes"])
    edge_total = sum(len(group) for group in final_subgraph["edges"])
    success_message = (
        f"Successfully extracted subgraph '{arg_data.extraction_name}' "
        f"with {graph_count} graph(s). The subgraph contains "
        f"{node_total} nodes and "
        f"{edge_total} edges. "
        "The extracted subgraph has been stored and is ready for "
        "visualization and analysis."
    )

    # Hand the record and the confirmation message back as a state update.
    reply = ToolMessage(content=success_message, tool_call_id=tool_call_id)
    return Command(
        update={
            "dic_extracted_graph": [extracted],
            "messages": [reply],
        }
    )
|