aiagents4pharma 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/__init__.py +11 -0
- aiagents4pharma/talk2aiagents4pharma/.dockerignore +13 -0
- aiagents4pharma/talk2aiagents4pharma/Dockerfile +133 -0
- aiagents4pharma/talk2aiagents4pharma/README.md +1 -0
- aiagents4pharma/talk2aiagents4pharma/__init__.py +5 -0
- aiagents4pharma/talk2aiagents4pharma/agents/__init__.py +6 -0
- aiagents4pharma/talk2aiagents4pharma/agents/main_agent.py +70 -0
- aiagents4pharma/talk2aiagents4pharma/configs/__init__.py +5 -0
- aiagents4pharma/talk2aiagents4pharma/configs/agents/__init__.py +5 -0
- aiagents4pharma/talk2aiagents4pharma/configs/agents/main_agent/default.yaml +29 -0
- aiagents4pharma/talk2aiagents4pharma/configs/app/__init__.py +0 -0
- aiagents4pharma/talk2aiagents4pharma/configs/app/frontend/__init__.py +0 -0
- aiagents4pharma/talk2aiagents4pharma/configs/app/frontend/default.yaml +102 -0
- aiagents4pharma/talk2aiagents4pharma/configs/config.yaml +4 -0
- aiagents4pharma/talk2aiagents4pharma/docker-compose/cpu/.env.example +23 -0
- aiagents4pharma/talk2aiagents4pharma/docker-compose/cpu/docker-compose.yml +93 -0
- aiagents4pharma/talk2aiagents4pharma/docker-compose/gpu/.env.example +23 -0
- aiagents4pharma/talk2aiagents4pharma/docker-compose/gpu/docker-compose.yml +108 -0
- aiagents4pharma/talk2aiagents4pharma/install.md +154 -0
- aiagents4pharma/talk2aiagents4pharma/states/__init__.py +5 -0
- aiagents4pharma/talk2aiagents4pharma/states/state_talk2aiagents4pharma.py +18 -0
- aiagents4pharma/talk2aiagents4pharma/tests/__init__.py +3 -0
- aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +312 -0
- aiagents4pharma/talk2biomodels/.dockerignore +13 -0
- aiagents4pharma/talk2biomodels/Dockerfile +104 -0
- aiagents4pharma/talk2biomodels/README.md +1 -0
- aiagents4pharma/talk2biomodels/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/agents/__init__.py +6 -0
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py +104 -0
- aiagents4pharma/talk2biomodels/api/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/api/ols.py +75 -0
- aiagents4pharma/talk2biomodels/api/uniprot.py +36 -0
- aiagents4pharma/talk2biomodels/configs/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/configs/agents/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/configs/agents/t2b_agent/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/configs/agents/t2b_agent/default.yaml +14 -0
- aiagents4pharma/talk2biomodels/configs/app/__init__.py +0 -0
- aiagents4pharma/talk2biomodels/configs/app/frontend/__init__.py +0 -0
- aiagents4pharma/talk2biomodels/configs/app/frontend/default.yaml +72 -0
- aiagents4pharma/talk2biomodels/configs/config.yaml +7 -0
- aiagents4pharma/talk2biomodels/configs/tools/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/configs/tools/ask_question/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +30 -0
- aiagents4pharma/talk2biomodels/configs/tools/custom_plotter/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/configs/tools/custom_plotter/default.yaml +8 -0
- aiagents4pharma/talk2biomodels/configs/tools/get_annotation/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/configs/tools/get_annotation/default.yaml +8 -0
- aiagents4pharma/talk2biomodels/install.md +63 -0
- aiagents4pharma/talk2biomodels/models/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/models/basico_model.py +125 -0
- aiagents4pharma/talk2biomodels/models/sys_bio_model.py +60 -0
- aiagents4pharma/talk2biomodels/states/__init__.py +6 -0
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +49 -0
- aiagents4pharma/talk2biomodels/tests/BIOMD0000000449_url.xml +1585 -0
- aiagents4pharma/talk2biomodels/tests/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/tests/article_on_model_537.pdf +0 -0
- aiagents4pharma/talk2biomodels/tests/test_api.py +31 -0
- aiagents4pharma/talk2biomodels/tests/test_ask_question.py +42 -0
- aiagents4pharma/talk2biomodels/tests/test_basico_model.py +67 -0
- aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +190 -0
- aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +92 -0
- aiagents4pharma/talk2biomodels/tests/test_integration.py +116 -0
- aiagents4pharma/talk2biomodels/tests/test_load_biomodel.py +35 -0
- aiagents4pharma/talk2biomodels/tests/test_param_scan.py +71 -0
- aiagents4pharma/talk2biomodels/tests/test_query_article.py +184 -0
- aiagents4pharma/talk2biomodels/tests/test_save_model.py +47 -0
- aiagents4pharma/talk2biomodels/tests/test_search_models.py +35 -0
- aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +44 -0
- aiagents4pharma/talk2biomodels/tests/test_steady_state.py +86 -0
- aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +67 -0
- aiagents4pharma/talk2biomodels/tools/__init__.py +17 -0
- aiagents4pharma/talk2biomodels/tools/ask_question.py +125 -0
- aiagents4pharma/talk2biomodels/tools/custom_plotter.py +165 -0
- aiagents4pharma/talk2biomodels/tools/get_annotation.py +342 -0
- aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +159 -0
- aiagents4pharma/talk2biomodels/tools/load_arguments.py +134 -0
- aiagents4pharma/talk2biomodels/tools/load_biomodel.py +44 -0
- aiagents4pharma/talk2biomodels/tools/parameter_scan.py +310 -0
- aiagents4pharma/talk2biomodels/tools/query_article.py +64 -0
- aiagents4pharma/talk2biomodels/tools/save_model.py +98 -0
- aiagents4pharma/talk2biomodels/tools/search_models.py +96 -0
- aiagents4pharma/talk2biomodels/tools/simulate_model.py +137 -0
- aiagents4pharma/talk2biomodels/tools/steady_state.py +187 -0
- aiagents4pharma/talk2biomodels/tools/utils.py +23 -0
- aiagents4pharma/talk2cells/README.md +1 -0
- aiagents4pharma/talk2cells/__init__.py +5 -0
- aiagents4pharma/talk2cells/agents/__init__.py +6 -0
- aiagents4pharma/talk2cells/agents/scp_agent.py +87 -0
- aiagents4pharma/talk2cells/states/__init__.py +6 -0
- aiagents4pharma/talk2cells/states/state_talk2cells.py +15 -0
- aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +22 -0
- aiagents4pharma/talk2cells/tools/__init__.py +6 -0
- aiagents4pharma/talk2cells/tools/scp_agent/__init__.py +6 -0
- aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py +27 -0
- aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py +78 -0
- aiagents4pharma/talk2knowledgegraphs/.dockerignore +13 -0
- aiagents4pharma/talk2knowledgegraphs/Dockerfile +131 -0
- aiagents4pharma/talk2knowledgegraphs/README.md +1 -0
- aiagents4pharma/talk2knowledgegraphs/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +99 -0
- aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +79 -0
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +13 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/__init__.py +0 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +33 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/database/milvus/__init__.py +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/database/milvus/default.yaml +61 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/ols_terms/default.yaml +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/reactome_pathways/default.yaml +3 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/uniprot_proteins/default.yaml +6 -0
- aiagents4pharma/talk2knowledgegraphs/configs/utils/pubchem_utils/default.yaml +5 -0
- aiagents4pharma/talk2knowledgegraphs/datasets/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/datasets/biobridge_primekg.py +607 -0
- aiagents4pharma/talk2knowledgegraphs/datasets/dataset.py +25 -0
- aiagents4pharma/talk2knowledgegraphs/datasets/primekg.py +212 -0
- aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py +210 -0
- aiagents4pharma/talk2knowledgegraphs/docker-compose/cpu/.env.example +23 -0
- aiagents4pharma/talk2knowledgegraphs/docker-compose/cpu/docker-compose.yml +93 -0
- aiagents4pharma/talk2knowledgegraphs/docker-compose/gpu/.env.example +23 -0
- aiagents4pharma/talk2knowledgegraphs/docker-compose/gpu/docker-compose.yml +108 -0
- aiagents4pharma/talk2knowledgegraphs/entrypoint.sh +180 -0
- aiagents4pharma/talk2knowledgegraphs/install.md +165 -0
- aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +886 -0
- aiagents4pharma/talk2knowledgegraphs/states/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +40 -0
- aiagents4pharma/talk2knowledgegraphs/tests/__init__.py +0 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +318 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +248 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +33 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +86 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +125 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +257 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +1444 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_multimodal_subgraph_extraction.py +159 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +152 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +201 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_database_milvus_connection_manager.py +812 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +51 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +49 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_nim_molmim.py +59 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +63 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py +47 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py +40 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +94 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ols.py +70 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_pubchem.py +45 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_reactome.py +44 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_uniprot.py +48 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +759 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +78 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_pubchem_utils.py +123 -0
- aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +11 -0
- aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +138 -0
- aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
- aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +965 -0
- aiagents4pharma/talk2knowledgegraphs/tools/multimodal_subgraph_extraction.py +374 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +291 -0
- aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +123 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/database/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py +586 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/embeddings.py +81 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/huggingface.py +111 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/nim_molmim.py +54 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +87 -0
- aiagents4pharma/talk2knowledgegraphs/utils/embeddings/sentence_transformer.py +73 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +12 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py +37 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py +129 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ols_terms.py +89 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/pubchem_strings.py +78 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/reactome_pathways.py +71 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/uniprot_proteins.py +98 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +762 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/multimodal_pcst.py +298 -0
- aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +229 -0
- aiagents4pharma/talk2knowledgegraphs/utils/kg_utils.py +67 -0
- aiagents4pharma/talk2knowledgegraphs/utils/pubchem_utils.py +104 -0
- aiagents4pharma/talk2scholars/.dockerignore +13 -0
- aiagents4pharma/talk2scholars/Dockerfile +104 -0
- aiagents4pharma/talk2scholars/README.md +1 -0
- aiagents4pharma/talk2scholars/__init__.py +7 -0
- aiagents4pharma/talk2scholars/agents/__init__.py +13 -0
- aiagents4pharma/talk2scholars/agents/main_agent.py +89 -0
- aiagents4pharma/talk2scholars/agents/paper_download_agent.py +96 -0
- aiagents4pharma/talk2scholars/agents/pdf_agent.py +101 -0
- aiagents4pharma/talk2scholars/agents/s2_agent.py +135 -0
- aiagents4pharma/talk2scholars/agents/zotero_agent.py +127 -0
- aiagents4pharma/talk2scholars/configs/__init__.py +7 -0
- aiagents4pharma/talk2scholars/configs/agents/__init__.py +7 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/__init__.py +7 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +52 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/default.yaml +19 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +19 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +44 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +19 -0
- aiagents4pharma/talk2scholars/configs/app/__init__.py +7 -0
- aiagents4pharma/talk2scholars/configs/app/frontend/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +72 -0
- aiagents4pharma/talk2scholars/configs/config.yaml +16 -0
- aiagents4pharma/talk2scholars/configs/tools/__init__.py +21 -0
- aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +26 -0
- aiagents4pharma/talk2scholars/configs/tools/paper_download/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/paper_download/default.yaml +124 -0
- aiagents4pharma/talk2scholars/configs/tools/question_and_answer/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/question_and_answer/default.yaml +62 -0
- aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/default.yaml +12 -0
- aiagents4pharma/talk2scholars/configs/tools/search/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +26 -0
- aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +26 -0
- aiagents4pharma/talk2scholars/configs/tools/zotero_read/__init__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +57 -0
- aiagents4pharma/talk2scholars/configs/tools/zotero_write/__inti__.py +3 -0
- aiagents4pharma/talk2scholars/configs/tools/zotero_write/default.yaml +55 -0
- aiagents4pharma/talk2scholars/docker-compose/cpu/.env.example +21 -0
- aiagents4pharma/talk2scholars/docker-compose/cpu/docker-compose.yml +90 -0
- aiagents4pharma/talk2scholars/docker-compose/gpu/.env.example +21 -0
- aiagents4pharma/talk2scholars/docker-compose/gpu/docker-compose.yml +105 -0
- aiagents4pharma/talk2scholars/install.md +122 -0
- aiagents4pharma/talk2scholars/state/__init__.py +7 -0
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +98 -0
- aiagents4pharma/talk2scholars/tests/__init__.py +3 -0
- aiagents4pharma/talk2scholars/tests/test_agents_main_agent.py +256 -0
- aiagents4pharma/talk2scholars/tests/test_agents_paper_agents_download_agent.py +139 -0
- aiagents4pharma/talk2scholars/tests/test_agents_pdf_agent.py +114 -0
- aiagents4pharma/talk2scholars/tests/test_agents_s2_agent.py +198 -0
- aiagents4pharma/talk2scholars/tests/test_agents_zotero_agent.py +160 -0
- aiagents4pharma/talk2scholars/tests/test_s2_tools_display_dataframe.py +91 -0
- aiagents4pharma/talk2scholars/tests/test_s2_tools_query_dataframe.py +191 -0
- aiagents4pharma/talk2scholars/tests/test_states_state.py +38 -0
- aiagents4pharma/talk2scholars/tests/test_tools_paper_downloader.py +507 -0
- aiagents4pharma/talk2scholars/tests/test_tools_question_and_answer_tool.py +105 -0
- aiagents4pharma/talk2scholars/tests/test_tools_s2_multi.py +307 -0
- aiagents4pharma/talk2scholars/tests/test_tools_s2_retrieve.py +67 -0
- aiagents4pharma/talk2scholars/tests/test_tools_s2_search.py +286 -0
- aiagents4pharma/talk2scholars/tests/test_tools_s2_single.py +298 -0
- aiagents4pharma/talk2scholars/tests/test_utils_arxiv_downloader.py +469 -0
- aiagents4pharma/talk2scholars/tests/test_utils_base_paper_downloader.py +598 -0
- aiagents4pharma/talk2scholars/tests/test_utils_biorxiv_downloader.py +669 -0
- aiagents4pharma/talk2scholars/tests/test_utils_medrxiv_downloader.py +500 -0
- aiagents4pharma/talk2scholars/tests/test_utils_nvidia_nim_reranker.py +117 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_answer_formatter.py +67 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_batch_processor.py +92 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_collection_manager.py +173 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_document_processor.py +68 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_generate_answer.py +72 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_gpu_detection.py +129 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_paper_loader.py +116 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_rag_pipeline.py +88 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_retrieve_chunks.py +190 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_singleton_manager.py +159 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_vector_normalization.py +121 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pdf_vector_store.py +406 -0
- aiagents4pharma/talk2scholars/tests/test_utils_pubmed_downloader.py +1007 -0
- aiagents4pharma/talk2scholars/tests/test_utils_read_helper_utils.py +106 -0
- aiagents4pharma/talk2scholars/tests/test_utils_s2_utils_ext_ids.py +403 -0
- aiagents4pharma/talk2scholars/tests/test_utils_tool_helper_utils.py +85 -0
- aiagents4pharma/talk2scholars/tests/test_utils_zotero_human_in_the_loop.py +266 -0
- aiagents4pharma/talk2scholars/tests/test_utils_zotero_path.py +496 -0
- aiagents4pharma/talk2scholars/tests/test_utils_zotero_pdf_downloader_utils.py +46 -0
- aiagents4pharma/talk2scholars/tests/test_utils_zotero_read.py +743 -0
- aiagents4pharma/talk2scholars/tests/test_utils_zotero_write.py +151 -0
- aiagents4pharma/talk2scholars/tools/__init__.py +9 -0
- aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +12 -0
- aiagents4pharma/talk2scholars/tools/paper_download/paper_downloader.py +442 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/__init__.py +22 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/arxiv_downloader.py +207 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/base_paper_downloader.py +336 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/biorxiv_downloader.py +313 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/medrxiv_downloader.py +196 -0
- aiagents4pharma/talk2scholars/tools/paper_download/utils/pubmed_downloader.py +323 -0
- aiagents4pharma/talk2scholars/tools/pdf/__init__.py +7 -0
- aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +170 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py +37 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/answer_formatter.py +62 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/batch_processor.py +198 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/collection_manager.py +172 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/document_processor.py +76 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py +97 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/get_vectorstore.py +59 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/gpu_detection.py +150 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py +97 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/paper_loader.py +123 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/rag_pipeline.py +113 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py +197 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/singleton_manager.py +140 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py +86 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/vector_normalization.py +150 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py +327 -0
- aiagents4pharma/talk2scholars/tools/s2/__init__.py +21 -0
- aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py +110 -0
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +111 -0
- aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +233 -0
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +128 -0
- aiagents4pharma/talk2scholars/tools/s2/search.py +101 -0
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +102 -0
- aiagents4pharma/talk2scholars/tools/s2/utils/__init__.py +5 -0
- aiagents4pharma/talk2scholars/tools/s2/utils/multi_helper.py +223 -0
- aiagents4pharma/talk2scholars/tools/s2/utils/search_helper.py +205 -0
- aiagents4pharma/talk2scholars/tools/s2/utils/single_helper.py +216 -0
- aiagents4pharma/talk2scholars/tools/zotero/__init__.py +7 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/__init__.py +7 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/read_helper.py +270 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/review_helper.py +74 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/write_helper.py +194 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_path.py +180 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_pdf_downloader.py +133 -0
- aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +105 -0
- aiagents4pharma/talk2scholars/tools/zotero/zotero_review.py +162 -0
- aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +91 -0
- aiagents4pharma-0.0.0.dist-info/METADATA +335 -0
- aiagents4pharma-0.0.0.dist-info/RECORD +336 -0
- aiagents4pharma-0.0.0.dist-info/WHEEL +4 -0
- aiagents4pharma-0.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,762 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Extraction of multimodal subgraph using Prize-Collecting Steiner Tree (PCST) algorithm.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import platform
|
|
8
|
+
import subprocess
|
|
9
|
+
from typing import NamedTuple
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
import pcst_fast
|
|
14
|
+
from pymilvus import Collection
|
|
15
|
+
|
|
16
|
+
try:
|
|
17
|
+
import cudf # type: ignore
|
|
18
|
+
import cupy as cp # type: ignore
|
|
19
|
+
|
|
20
|
+
CUDF_AVAILABLE = True
|
|
21
|
+
except ImportError:
|
|
22
|
+
CUDF_AVAILABLE = False
|
|
23
|
+
cudf = None
|
|
24
|
+
cp = None
|
|
25
|
+
|
|
26
|
+
# Initialize logger
|
|
27
|
+
logging.basicConfig(level=logging.INFO)
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class SystemDetector:
    """Detect system capabilities and choose appropriate libraries.

    All attributes are computed once, when the instance is constructed:

    Attributes:
        os_type: Lower-cased ``platform.system()`` value.
        architecture: Lower-cased ``platform.machine()`` value.
        has_nvidia_gpu: True when running ``nvidia-smi`` exited with code 0.
        use_gpu: True when an NVIDIA GPU was detected and the OS is not macOS.
    """

    def __init__(self):
        self.os_type = platform.system().lower()  # 'windows', 'linux', 'darwin'
        self.architecture = platform.machine().lower()  # 'x86_64', 'arm64', etc.
        self.has_nvidia_gpu = self._detect_nvidia_gpu()
        self.use_gpu = self.has_nvidia_gpu and self.os_type != "darwin"  # No CUDA on macOS

        logger.info("System Detection Results:")
        logger.info(" OS: %s", self.os_type)
        logger.info(" Architecture: %s", self.architecture)
        logger.info(" NVIDIA GPU detected: %s", self.has_nvidia_gpu)
        logger.info(" Will use GPU acceleration: %s", self.use_gpu)

    def _detect_nvidia_gpu(self) -> bool:
        """Detect if NVIDIA GPU is available.

        Runs ``nvidia-smi`` with a 10 second timeout and treats a zero exit
        code as "GPU present".

        Returns:
            True if ``nvidia-smi`` ran and exited with code 0, False if it
            exited non-zero, timed out, or could not be launched at all.
        """
        try:
            # Try nvidia-smi command
            result = subprocess.run(
                ["nvidia-smi"], capture_output=True, text=True, timeout=10, check=False
            )
        except (OSError, subprocess.SubprocessError):
            # OSError covers every launch failure (binary missing, present but
            # not executable, ...), not just FileNotFoundError; SubprocessError
            # covers TimeoutExpired. A failed probe must never abort
            # construction -- it simply means "no usable GPU".
            return False
        return result.returncode == 0

    def get_system_info(self) -> dict:
        """Get comprehensive system information.

        Returns:
            A dict with the keys ``os_type``, ``architecture``,
            ``has_nvidia_gpu`` and ``use_gpu`` holding the current values of
            the attributes of the same name.
        """
        return {
            "os_type": self.os_type,
            "architecture": self.architecture,
            "has_nvidia_gpu": self.has_nvidia_gpu,
            "use_gpu": self.use_gpu,
        }

    def is_gpu_compatible(self) -> bool:
        """Check if the system is compatible with GPU acceleration.

        Recomputed from ``has_nvidia_gpu`` and ``os_type`` rather than read
        from ``use_gpu``, so it is unaffected by later reassignment of
        ``use_gpu``.

        Returns:
            True when an NVIDIA GPU was detected and the OS is not macOS.
        """
        return self.has_nvidia_gpu and self.os_type != "darwin"
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class DynamicLibraryLoader:
    """Dynamically load libraries based on system capabilities.

    Attributes:
        detector: The detector passed to the constructor.
        use_gpu: Whether the GPU libraries are in use.
        pd, np: The pandas and numpy modules.
        cudf, cp: The cudf and cupy modules in GPU mode, otherwise None.
        py: Array library in use (cupy in GPU mode, numpy otherwise).
        df: Dataframe library in use (cudf in GPU mode, pandas otherwise).
        normalize_vectors: Whether ``normalize_matrix`` normalizes (GPU only).
        metric_type: ``"IP"`` in GPU mode, ``"COSINE"`` otherwise.
    """

    def __init__(self, detector: "SystemDetector"):
        """Select libraries and similarity settings for the detected system.

        Args:
            detector: Provides ``use_gpu``. Note that ``detector.use_gpu`` is
                set to False here when a GPU was detected but cudf/cupy could
                not be imported.
        """
        self.detector = detector
        self.use_gpu = detector.use_gpu

        # Initialize attributes that will be set later
        self.py = None
        self.df = None
        self.pd = None
        self.np = None
        self.cudf = None
        self.cp = None

        # Import libraries based on system capabilities
        self._import_libraries()

        # Dynamic settings based on hardware; self.use_gpu may have been
        # switched off by _import_libraries() if the GPU libraries are missing.
        self.normalize_vectors = self.use_gpu  # Only normalize for GPU
        self.metric_type = "IP" if self.use_gpu else "COSINE"

        logger.info("Library Configuration:")
        logger.info(" Using GPU acceleration: %s", self.use_gpu)
        logger.info(" Vector normalization: %s", self.normalize_vectors)
        logger.info(" Metric type: %s", self.metric_type)

    def _import_libraries(self):
        """Dynamically import libraries based on system capabilities."""
        # Set base libraries
        self.pd = pd
        self.np = np

        # Conditionally import GPU libraries
        if self.detector.use_gpu:
            if CUDF_AVAILABLE:
                self.cudf = cudf
                self.cp = cp
                self.py = cp  # Use cupy for array operations
                self.df = cudf  # Use cudf for dataframes
                logger.info("Successfully imported GPU libraries (cudf, cupy)")
            else:
                logger.error("cudf or cupy not found. Falling back to CPU mode.")
                # Record the fallback on both objects so they stay in agreement.
                self.detector.use_gpu = False
                self.use_gpu = False
                self._setup_cpu_mode()
        else:
            self._setup_cpu_mode()

    def _setup_cpu_mode(self):
        """Setup CPU mode with numpy and pandas."""
        self.py = self.np  # Use numpy for array operations
        self.df = self.pd  # Use pandas for dataframes
        self.normalize_vectors = False
        self.metric_type = "COSINE"
        logger.info("Using CPU mode with numpy and pandas")

    def normalize_matrix(self, matrix, axis: int = 1):
        """Normalize matrix using appropriate library.

        Args:
            matrix: The data to normalize.
            axis: Axis along which the norms are taken.

        Returns:
            ``matrix`` unchanged when normalization is disabled or GPU mode is
            off; otherwise a float32 cupy array divided by its norms along
            ``axis``.
        """
        if not self.normalize_vectors:
            return matrix

        if self.use_gpu:
            # Use cupy for GPU
            matrix_cp = self.cp.asarray(matrix).astype(self.cp.float32)
            norms = self.cp.linalg.norm(matrix_cp, axis=axis, keepdims=True)
            # NOTE(review): a slice with zero norm is divided by zero here --
            # confirm callers never pass all-zero vectors.
            return matrix_cp / norms
        # CPU mode doesn't normalize for COSINE similarity
        return matrix

    def to_list(self, data):
        """Convert data to list format.

        Args:
            data: Object exposing ``to_arrow()``, ``tolist()``, or plain
                iteration.

        Returns:
            A Python list built from ``data``.
        """
        # Arrow conversion is tried first: an object may expose both methods
        # while its ``tolist()`` refuses to run (presumably the case for cuDF
        # Series/Index -- verify against the cuDF docs), which would make this
        # branch unreachable if ``tolist`` were tested first.
        if hasattr(data, "to_arrow"):
            return data.to_arrow().to_pylist()
        if hasattr(data, "tolist"):
            return data.tolist()
        return list(data)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class MultimodalPCSTPruning(NamedTuple):
|
|
155
|
+
"""
|
|
156
|
+
Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
|
|
157
|
+
(He et al., 'G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and
|
|
158
|
+
Question Answering', NeurIPS 2024) paper.
|
|
159
|
+
https://arxiv.org/abs/2402.07630
|
|
160
|
+
https://github.com/XiaoxinHe/G-Retriever/blob/main/src/dataset/utils/retrieval.py
|
|
161
|
+
|
|
162
|
+
Args:
|
|
163
|
+
topk: The number of top nodes to consider.
|
|
164
|
+
topk_e: The number of top edges to consider.
|
|
165
|
+
cost_e: The cost of the edges.
|
|
166
|
+
c_const: The constant value for the cost of the edges computation.
|
|
167
|
+
root: The root node of the subgraph, -1 for unrooted.
|
|
168
|
+
num_clusters: The number of clusters.
|
|
169
|
+
pruning: The pruning strategy to use.
|
|
170
|
+
verbosity_level: The verbosity level.
|
|
171
|
+
use_description: Whether to use description embeddings.
|
|
172
|
+
metric_type: The similarity metric type (dynamic based on hardware).
|
|
173
|
+
loader: The dynamic library loader instance.
|
|
174
|
+
"""
|
|
175
|
+
|
|
176
|
+
topk: int = 3
|
|
177
|
+
topk_e: int = 3
|
|
178
|
+
cost_e: float = 0.5
|
|
179
|
+
c_const: float = 0.01
|
|
180
|
+
root: int = -1
|
|
181
|
+
num_clusters: int = 1
|
|
182
|
+
pruning: str = "gw"
|
|
183
|
+
verbosity_level: int = 0
|
|
184
|
+
use_description: bool = False
|
|
185
|
+
metric_type: str = None # Will be set dynamically
|
|
186
|
+
loader: DynamicLibraryLoader = None
|
|
187
|
+
|
|
188
|
+
def prepare_collections(self, cfg: dict, modality: str) -> dict:
    """
    Open and load the Milvus collections required for subgraph extraction.

    Args:
        cfg: The configuration object; ``cfg.milvus_db.database_name`` is used
            as the prefix of every collection name.
        modality: The modality used for the subgraph extraction. Unless it is
            "prompt", a node-type specific collection is opened as well, with
            any '/' in the modality replaced by '_'.

    Returns:
        A dictionary with the keys "nodes", "nodes_type" (only for non-prompt
        modalities) and "edges", each mapped to a loaded collection.
    """
    db_name = cfg.milvus_db.database_name

    # Generic node collection, always needed
    colls = {"nodes": Collection(name=f"{db_name}_nodes")}

    # Node-type specific collection, skipped for plain prompts
    if modality != "prompt":
        type_suffix = modality.replace("/", "_")
        colls["nodes_type"] = Collection(f"{db_name}_nodes_{type_suffix}")

    # Edge collection
    colls["edges"] = Collection(name=f"{db_name}_edges")

    # Load each collection before handing it back
    for collection in colls.values():
        collection.load()

    return colls
|
|
219
|
+
|
|
220
|
+
async def load_edge_index_async(self, cfg: dict, _connection_manager=None) -> np.ndarray:
    """
    Load the edge index from the Milvus edges collection via a worker thread.

    The blocking collection queries are wrapped in a local function that is
    executed with ``asyncio.to_thread``.

    Args:
        cfg: The configuration object containing the Milvus setup
            (``cfg.milvus_db.database_name`` and, optionally,
            ``cfg.milvus_db.query_batch_size``, default 10000).
        _connection_manager: Unused parameter for interface compatibility.

    Returns:
        Array built from ``[head_indices, tail_indices]`` using the loader's
        array module, i.e. two rows with one column per retrieved edge.
    """
    logger.info("Loading edge index from Milvus collection (hybrid)")

    def load_edges_sync():
        """Query all edges batch by batch and assemble the edge index."""
        edges_collection = Collection(name=f"{cfg.milvus_db.database_name}_edges")
        edges_collection.load()

        step = getattr(cfg.milvus_db, "query_batch_size", 10000)
        n_edges = edges_collection.num_entities
        logger.info("Total edges to process: %d", n_edges)

        heads, tails = [], []
        # NOTE(review): batching by triplet_index range presumably relies on
        # triplet_index being contiguous in [0, num_entities) — TODO confirm.
        for lo in range(0, n_edges, step):
            hi = min(lo + step, n_edges)
            logger.debug("Processing edge batch: %d to %d", lo, hi)

            rows = edges_collection.query(
                expr=f"triplet_index >= {lo} and triplet_index < {hi}",
                output_fields=["head_index", "tail_index"],
            )
            for row in rows:
                heads.append(row["head_index"])
                tails.append(row["tail_index"])

        # Two-row array (heads, tails) as consumed by the PCST steps
        edge_index = self.loader.py.array([heads, tails])
        logger.info("Edge index loaded (hybrid): shape %s", str(edge_index.shape))
        return edge_index

    # Keep the blocking Milvus calls off the event loop
    return await asyncio.to_thread(load_edges_sync)
|
|
275
|
+
|
|
276
|
+
def load_edge_index(self, cfg: dict) -> np.ndarray:
    """
    Load the edge index synchronously from the Milvus edges collection.

    The collection is queried in ``triplet_index`` ranges and the
    ``head_index`` / ``tail_index`` fields of every returned row are collected.

    Args:
        cfg: The configuration object containing the Milvus setup
            (``cfg.milvus_db.database_name`` and, optionally,
            ``cfg.milvus_db.query_batch_size``, default 10000).

    Returns:
        Array built from ``[head_indices, tail_indices]`` using the loader's
        array module, i.e. two rows with one column per retrieved edge.
    """
    logger.info("Loading edge index from Milvus collection (sync)")

    edges_collection = Collection(name=f"{cfg.milvus_db.database_name}_edges")
    edges_collection.load()

    step = getattr(cfg.milvus_db, "query_batch_size", 10000)
    n_edges = edges_collection.num_entities
    logger.info("Total edges to process: %d", n_edges)

    heads, tails = [], []
    # NOTE(review): batching by triplet_index range presumably relies on
    # triplet_index being contiguous in [0, num_entities) — TODO confirm.
    for lo in range(0, n_edges, step):
        hi = min(lo + step, n_edges)
        logger.debug("Processing edge batch: %d to %d", lo, hi)

        rows = edges_collection.query(
            expr=f"triplet_index >= {lo} and triplet_index < {hi}",
            output_fields=["head_index", "tail_index"],
        )
        for row in rows:
            heads.append(row["head_index"])
            tails.append(row["tail_index"])

    # Two-row array (heads, tails) as consumed by the PCST steps
    edge_index = self.loader.py.array([heads, tails])
    logger.info("Edge index loaded (sync): shape %s", str(edge_index.shape))

    return edge_index
|
|
323
|
+
|
|
324
|
+
def _compute_node_prizes(self, query_emb: list, colls: dict) -> dict:
|
|
325
|
+
"""
|
|
326
|
+
Compute the node prizes based on the similarity between the query and nodes.
|
|
327
|
+
|
|
328
|
+
Args:
|
|
329
|
+
query_emb: The query embedding. This can be an embedding of
|
|
330
|
+
a prompt, sequence, or any other feature to be used for the subgraph extraction.
|
|
331
|
+
colls: The collections of nodes, node-type specific nodes, and edges in Milvus.
|
|
332
|
+
|
|
333
|
+
Returns:
|
|
334
|
+
The prizes of the nodes.
|
|
335
|
+
"""
|
|
336
|
+
# Initialize several variables
|
|
337
|
+
topk = min(self.topk, colls["nodes"].num_entities)
|
|
338
|
+
n_prizes = self.loader.py.zeros(colls["nodes"].num_entities, dtype=self.loader.py.float32)
|
|
339
|
+
|
|
340
|
+
# Get the actual metric type to use
|
|
341
|
+
actual_metric_type = self.metric_type or self.loader.metric_type
|
|
342
|
+
|
|
343
|
+
# Calculate similarity for text features and update the score
|
|
344
|
+
if self.use_description:
|
|
345
|
+
# Search the collection with the text embedding
|
|
346
|
+
res = colls["nodes"].search(
|
|
347
|
+
data=[query_emb],
|
|
348
|
+
anns_field="desc_emb",
|
|
349
|
+
param={"metric_type": actual_metric_type},
|
|
350
|
+
limit=topk,
|
|
351
|
+
output_fields=["node_id"],
|
|
352
|
+
)
|
|
353
|
+
else:
|
|
354
|
+
# Search the collection with the query embedding
|
|
355
|
+
res = colls["nodes_type"].search(
|
|
356
|
+
data=[query_emb],
|
|
357
|
+
anns_field="feat_emb",
|
|
358
|
+
param={"metric_type": actual_metric_type},
|
|
359
|
+
limit=topk,
|
|
360
|
+
output_fields=["node_id"],
|
|
361
|
+
)
|
|
362
|
+
|
|
363
|
+
# Update the prizes based on the search results
|
|
364
|
+
n_prizes[[r.id for r in res[0]]] = self.loader.py.arange(topk, 0, -1).astype(
|
|
365
|
+
self.loader.py.float32
|
|
366
|
+
)
|
|
367
|
+
|
|
368
|
+
return n_prizes
|
|
369
|
+
|
|
370
|
+
async def _compute_node_prizes_async(
|
|
371
|
+
self,
|
|
372
|
+
query_emb: list,
|
|
373
|
+
collection_name: str,
|
|
374
|
+
connection_manager,
|
|
375
|
+
use_description: bool = False,
|
|
376
|
+
) -> dict:
|
|
377
|
+
"""
|
|
378
|
+
Compute the node prizes asynchronously using connection manager.
|
|
379
|
+
|
|
380
|
+
Args:
|
|
381
|
+
query_emb: The query embedding
|
|
382
|
+
collection_name: Name of the collection to search
|
|
383
|
+
connection_manager: The MilvusConnectionManager instance
|
|
384
|
+
use_description: Whether to use description embeddings
|
|
385
|
+
|
|
386
|
+
Returns:
|
|
387
|
+
The prizes of the nodes
|
|
388
|
+
"""
|
|
389
|
+
# Get collection stats for initialization
|
|
390
|
+
stats = await connection_manager.async_get_collection_stats(collection_name)
|
|
391
|
+
num_entities = stats["num_entities"]
|
|
392
|
+
|
|
393
|
+
# Initialize prizes array
|
|
394
|
+
topk = min(self.topk, num_entities)
|
|
395
|
+
n_prizes = self.loader.py.zeros(num_entities, dtype=self.loader.py.float32)
|
|
396
|
+
|
|
397
|
+
# Get the actual metric type to use
|
|
398
|
+
actual_metric_type = self.metric_type or self.loader.metric_type
|
|
399
|
+
|
|
400
|
+
# Determine search field based on use_description
|
|
401
|
+
anns_field = "desc_emb" if use_description else "feat_emb"
|
|
402
|
+
|
|
403
|
+
# Perform async search
|
|
404
|
+
results = await connection_manager.async_search(
|
|
405
|
+
collection_name=collection_name,
|
|
406
|
+
data=[query_emb],
|
|
407
|
+
anns_field=anns_field,
|
|
408
|
+
param={"metric_type": actual_metric_type},
|
|
409
|
+
limit=topk,
|
|
410
|
+
output_fields=["node_id"],
|
|
411
|
+
)
|
|
412
|
+
|
|
413
|
+
# Update the prizes based on the search results
|
|
414
|
+
if results and len(results) > 0:
|
|
415
|
+
result_ids = [hit["id"] for hit in results[0]]
|
|
416
|
+
n_prizes[result_ids] = self.loader.py.arange(topk, 0, -1).astype(self.loader.py.float32)
|
|
417
|
+
|
|
418
|
+
return n_prizes
|
|
419
|
+
|
|
420
|
+
def _compute_edge_prizes(self, text_emb: list, colls: dict):
|
|
421
|
+
"""
|
|
422
|
+
Compute the edge prizes based on the similarity between the query and edges.
|
|
423
|
+
|
|
424
|
+
Args:
|
|
425
|
+
text_emb: The textual description embedding.
|
|
426
|
+
colls: The collections of nodes, node-type specific nodes, and edges in Milvus.
|
|
427
|
+
|
|
428
|
+
Returns:
|
|
429
|
+
The prizes of the edges.
|
|
430
|
+
"""
|
|
431
|
+
# Initialize several variables
|
|
432
|
+
topk_e = min(self.topk_e, colls["edges"].num_entities)
|
|
433
|
+
e_prizes = self.loader.py.zeros(colls["edges"].num_entities, dtype=self.loader.py.float32)
|
|
434
|
+
|
|
435
|
+
# Get the actual metric type to use
|
|
436
|
+
actual_metric_type = self.metric_type or self.loader.metric_type
|
|
437
|
+
|
|
438
|
+
# Search the collection with the query embedding
|
|
439
|
+
res = colls["edges"].search(
|
|
440
|
+
data=[text_emb],
|
|
441
|
+
anns_field="feat_emb",
|
|
442
|
+
param={"metric_type": actual_metric_type},
|
|
443
|
+
limit=topk_e, # Only retrieve the top-k edges
|
|
444
|
+
output_fields=["head_id", "tail_id"],
|
|
445
|
+
)
|
|
446
|
+
|
|
447
|
+
# Update the prizes based on the search results
|
|
448
|
+
e_prizes[[r.id for r in res[0]]] = [r.score for r in res[0]]
|
|
449
|
+
|
|
450
|
+
# Further process the edge_prizes
|
|
451
|
+
unique_prizes, inverse_indices = self.loader.py.unique(e_prizes, return_inverse=True)
|
|
452
|
+
topk_e_values = unique_prizes[self.loader.py.argsort(-unique_prizes)[:topk_e]]
|
|
453
|
+
last_topk_e_value = topk_e
|
|
454
|
+
for k in range(topk_e):
|
|
455
|
+
indices = inverse_indices == (unique_prizes == topk_e_values[k]).nonzero()[0]
|
|
456
|
+
value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
|
|
457
|
+
e_prizes[indices] = value
|
|
458
|
+
last_topk_e_value = value * (1 - self.c_const)
|
|
459
|
+
|
|
460
|
+
return e_prizes
|
|
461
|
+
|
|
462
|
+
async def _compute_edge_prizes_async(
|
|
463
|
+
self, text_emb: list, collection_name: str, connection_manager
|
|
464
|
+
) -> dict:
|
|
465
|
+
"""
|
|
466
|
+
Compute the edge prizes asynchronously using connection manager.
|
|
467
|
+
|
|
468
|
+
Args:
|
|
469
|
+
text_emb: The textual description embedding
|
|
470
|
+
collection_name: Name of the edges collection
|
|
471
|
+
connection_manager: The MilvusConnectionManager instance
|
|
472
|
+
|
|
473
|
+
Returns:
|
|
474
|
+
The prizes of the edges
|
|
475
|
+
"""
|
|
476
|
+
# Get collection stats for initialization
|
|
477
|
+
stats = await connection_manager.async_get_collection_stats(collection_name)
|
|
478
|
+
num_entities = stats["num_entities"]
|
|
479
|
+
|
|
480
|
+
# Initialize prizes array
|
|
481
|
+
topk_e = min(self.topk_e, num_entities)
|
|
482
|
+
e_prizes = self.loader.py.zeros(num_entities, dtype=self.loader.py.float32)
|
|
483
|
+
|
|
484
|
+
# Get the actual metric type to use
|
|
485
|
+
actual_metric_type = self.metric_type or self.loader.metric_type
|
|
486
|
+
|
|
487
|
+
# Perform async search
|
|
488
|
+
results = await connection_manager.async_search(
|
|
489
|
+
collection_name=collection_name,
|
|
490
|
+
data=[text_emb],
|
|
491
|
+
anns_field="feat_emb",
|
|
492
|
+
param={"metric_type": actual_metric_type},
|
|
493
|
+
limit=topk_e,
|
|
494
|
+
output_fields=["head_id", "tail_id"],
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
# Update the prizes based on the search results
|
|
498
|
+
if results and len(results) > 0:
|
|
499
|
+
result_ids = [hit["id"] for hit in results[0]]
|
|
500
|
+
result_scores = [hit["distance"] for hit in results[0]] # Use distance/score
|
|
501
|
+
e_prizes[result_ids] = result_scores
|
|
502
|
+
|
|
503
|
+
# Process edge prizes using helper method
|
|
504
|
+
return self._process_edge_prizes(e_prizes, topk_e)
|
|
505
|
+
|
|
506
|
+
def _process_edge_prizes(self, e_prizes, topk_e):
|
|
507
|
+
"""Helper method to process edge prizes and reduce complexity."""
|
|
508
|
+
unique_prizes, inverse_indices = self.loader.py.unique(e_prizes, return_inverse=True)
|
|
509
|
+
sorted_indices = self.loader.py.argsort(-unique_prizes)[:topk_e]
|
|
510
|
+
topk_e_values = unique_prizes[sorted_indices]
|
|
511
|
+
last_topk_e_value = topk_e
|
|
512
|
+
|
|
513
|
+
for k in range(topk_e):
|
|
514
|
+
indices = inverse_indices == (unique_prizes == topk_e_values[k]).nonzero()[0]
|
|
515
|
+
value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
|
|
516
|
+
e_prizes[indices] = value
|
|
517
|
+
last_topk_e_value = value * (1 - self.c_const)
|
|
518
|
+
|
|
519
|
+
return e_prizes
|
|
520
|
+
|
|
521
|
+
def compute_prizes(self, text_emb: list, query_emb: list, colls: dict) -> dict:
    """
    Compute node and edge prizes for a query.

    Node prizes come from ``_compute_node_prizes`` (using ``query_emb``) and
    edge prizes from ``_compute_edge_prizes`` (using ``text_emb``). Note that
    the node and edge embeddings shall use the same embedding model and
    dimensions as the query.

    Args:
        text_emb: The textual description embedding.
        query_emb: The query embedding. This can be an embedding of
            a prompt, sequence, or any other feature to be used for the subgraph extraction.
        colls: The collections of nodes, node-type specific nodes, and edges in Milvus.

    Returns:
        A dictionary with the node prizes under "nodes" and the edge prizes
        under "edges".
    """
    # Nodes first
    logger.info("_compute_node_prizes")
    node_prizes = self._compute_node_prizes(query_emb, colls)

    # Then edges
    logger.info("_compute_edge_prizes")
    edge_prizes = self._compute_edge_prizes(text_emb, colls)

    return dict(nodes=node_prizes, edges=edge_prizes)
|
|
546
|
+
|
|
547
|
+
async def compute_prizes_async(
    self, text_emb: list, query_emb: list, cfg: dict, modality: str
) -> dict:
    """
    Awaitable wrapper around the synchronous prize computation.

    The collections are prepared with ``prepare_collections`` and the
    synchronous ``compute_prizes`` is then executed through
    ``asyncio.to_thread``.

    Args:
        text_emb: The textual description embedding
        query_emb: The query embedding
        cfg: The configuration dictionary containing the Milvus setup
        modality: The modality to use for the subgraph extraction

    Returns:
        The dictionary returned by ``compute_prizes`` (node and edge prizes)
    """
    logger.info("Computing prizes in parallel (hybrid async/sync)")

    # Collections are opened here; only the prize computation is off-loaded
    collections = self.prepare_collections(cfg, modality)
    prizes = await asyncio.to_thread(self.compute_prizes, text_emb, query_emb, collections)
    return prizes
|
|
567
|
+
|
|
568
|
+
def compute_subgraph_costs(self, edge_index, num_nodes: int, prizes: dict):
    """
    Build the PCST inputs (edges, prizes, costs) as proposed by the G-Retriever paper.

    Edges whose prize does not exceed the cost threshold stay "real" edges with
    cost ``threshold - prize``. Every other edge is replaced by a virtual node
    carrying the prize ``prize - threshold`` and two zero-cost edges
    (src -> virtual, virtual -> dst).

    Args:
        edge_index: The edge index of the graph, consisting of source and destination nodes.
        num_nodes: The number of nodes in the graph; virtual node ids start here.
        prizes: The prizes of the nodes ("nodes") and the edges ("edges").

    Returns:
        A 4-tuple ``(edges_dict, final_prizes, all_costs, mapping)``:
        edges_dict: {"edges": real edges followed by virtual edges,
            "num_prior_edges": number of real edges}.
        final_prizes: Node prizes followed by the virtual-node prizes.
        all_costs: Real edge costs followed by zeros for the virtual edges.
        mapping: {"edges": position among real edges -> original edge index,
            "nodes": virtual node id -> original edge index}.
    """
    xp = self.loader.py
    edge_prizes = prizes["edges"]

    # Cost threshold: never above just-below the largest edge prize
    cost_threshold = min(
        self.cost_e,
        xp.max(edge_prizes).item() * (1 - self.c_const / 2),
    )

    # Split edges into real (cheap prize) and virtual (prize above threshold)
    logger.info("Creating masks for real and virtual edges")
    is_real = edge_prizes <= cost_threshold
    is_virtual = ~is_real

    # Real edges keep their endpoints and pay the remaining cost
    logger.info("Computing real edges")
    real_idx = xp.nonzero(is_real)[0]
    real_src = edge_index[0][real_idx]
    real_dst = edge_index[1][real_idx]
    real_edges = xp.stack([real_src, real_dst], axis=1)
    real_costs = cost_threshold - edge_prizes[real_idx]

    # Position within the real-edge list -> original edge index
    logger.info("Creating mapping for real edges")
    mapping_edges = dict(enumerate(self.loader.to_list(real_idx)))

    # Virtual edges: collect endpoints and the surplus prize
    logger.info("Computing virtual edges")
    virt_idx = xp.nonzero(is_virtual)[0]
    virt_src = edge_index[0][virt_idx]
    virt_dst = edge_index[1][virt_idx]
    virt_prizes = edge_prizes[virt_idx] - cost_threshold

    # One fresh node id per virtual edge, numbered after the real nodes
    logger.info("Generating virtual node IDs")
    virt_count = virt_idx.shape[0]
    virt_node_ids = xp.arange(num_nodes, num_nodes + virt_count)

    # Each virtual node is wired as (src -> virtual) and (virtual -> dst)
    logger.info("Creating virtual edges")
    into_virtual = xp.stack([virt_src, virt_node_ids], axis=1)
    out_of_virtual = xp.stack([virt_node_ids, virt_dst], axis=1)
    virt_edges = xp.concatenate([into_virtual, out_of_virtual], axis=0)
    virt_costs = xp.zeros((virt_edges.shape[0],), dtype=real_costs.dtype)

    # Real entries always precede virtual ones
    logger.info("Combining real and virtual edges/costs")
    all_edges = xp.concatenate([real_edges, virt_edges], axis=0)
    all_costs = xp.concatenate([real_costs, virt_costs], axis=0)

    logger.info("Getting final prizes")
    final_prizes = xp.concatenate([prizes["nodes"], virt_prizes], axis=0)

    # Virtual node id -> original edge index
    logger.info("Creating mapping for virtual nodes")
    mapping_nodes = dict(
        zip(
            self.loader.to_list(virt_node_ids),
            self.loader.to_list(virt_idx),
            strict=False,
        )
    )

    logger.info("Building return values")
    edges_dict = {"edges": all_edges, "num_prior_edges": real_edges.shape[0]}
    mapping = {"edges": mapping_edges, "nodes": mapping_nodes}

    return edges_dict, final_prizes, all_costs, mapping
|
|
664
|
+
|
|
665
|
+
def get_subgraph_nodes_edges(
    self, num_nodes: int, vertices, edges_dict: dict, mapping: dict
) -> dict:
    """
    Get the selected nodes and edges of the subgraph based on the vertices and edges computed
    by the PCST algorithm.

    Args:
        num_nodes: The number of nodes in the graph. Vertex ids at or above this
            value are treated as virtual nodes.
        vertices: The vertices selected by the PCST algorithm.
        edges_dict: A dictionary with the selected edge positions ("edges"), the
            number of real edges ("num_prior_edges") and the graph's edge index
            ("edge_index").
        mapping: A dictionary containing the mapping of nodes and edges
            (mapping["edges"]: real-edge position -> original edge index;
            mapping["nodes"]: virtual node id -> original edge index).

    Returns:
        The selected nodes and edges of the extracted subgraph, as
        {"nodes": ..., "edges": ...} with original node and edge indices.
    """
    # Get edges information
    edges = edges_dict["edges"]
    num_prior_edges = edges_dict["num_prior_edges"]

    # Retrieve the selected nodes and edges based on the given vertices and edges
    # Real vertices are kept as they are
    subgraph_nodes = vertices[vertices < num_nodes]
    # Selected positions below num_prior_edges are real edges; map them back
    subgraph_edges = [mapping["edges"][e.item()] for e in edges if e < num_prior_edges]
    # Selected virtual vertices stand for original edges
    virtual_vertices = vertices[vertices >= num_nodes]
    if len(virtual_vertices) > 0:
        virtual_edges = [mapping["nodes"][i.item()] for i in virtual_vertices]
        # NOTE(review): "edges" becomes an array here but stays a Python list
        # when no virtual vertex was selected — confirm callers accept both.
        subgraph_edges = self.loader.py.array(subgraph_edges + virtual_edges)
    # Add the endpoints of every selected edge to the node set
    edge_index = edges_dict["edge_index"][:, subgraph_edges]
    subgraph_nodes = self.loader.py.unique(
        self.loader.py.concatenate([subgraph_nodes, edge_index[0], edge_index[1]])
    )

    return {"nodes": subgraph_nodes, "edges": subgraph_edges}
|
|
698
|
+
|
|
699
|
+
def extract_subgraph(self, text_emb: list, query_emb: list, modality: str, cfg: dict) -> dict:
    """
    Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.

    Steps: prepare the Milvus collections, load the edge index, compute node
    and edge prizes, derive the PCST costs, run ``pcst_fast`` and map the
    result back to original node and edge indices.

    Args:
        text_emb: The textual description embedding.
        query_emb: The query embedding. This can be an embedding of
            a prompt, sequence, or any other feature to be used for the subgraph extraction.
        modality: The modality to use for the subgraph extraction
            (e.g., "text", "sequence", "smiles").
        cfg: The configuration dictionary containing the Milvus setup.

    Returns:
        The selected nodes and edges of the subgraph, as returned by
        ``get_subgraph_nodes_edges``.

    Raises:
        AssertionError: If ``topk`` or ``topk_e`` is not greater than 0.
    """
    # Validate the retrieval sizes up front so that an invalid configuration
    # fails before any collection is loaded or queried.
    assert self.topk > 0, "topk must be greater than 0"
    assert self.topk_e > 0, "topk_e must be greater than 0"

    # Load the collections for nodes
    logger.log(logging.INFO, "Preparing collections")
    colls = self.prepare_collections(cfg, modality)
    num_nodes = colls["nodes"].num_entities

    # Load edge index directly from Milvus (replaces pickle cache)
    logger.log(logging.INFO, "Loading edge index from Milvus")
    edge_index = self.load_edge_index(cfg)

    # Retrieve the top-k nodes and edges based on the query embedding
    logger.log(logging.INFO, "compute_prizes")
    prizes = self.compute_prizes(text_emb, query_emb, colls)

    # Compute costs in constructing the subgraph
    logger.log(logging.INFO, "compute_subgraph_costs")
    edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
        edge_index, num_nodes, prizes
    )

    # Retrieve the subgraph using the PCST algorithm
    logger.log(logging.INFO, "Running PCST algorithm")
    result_vertices, result_edges = pcst_fast.pcst_fast(
        edges_dict["edges"].tolist(),
        prizes.tolist(),
        costs.tolist(),
        self.root,
        self.num_clusters,
        self.pruning,
        self.verbosity_level,
    )

    # Get subgraph nodes and edges based on the result of the PCST algorithm
    logger.log(logging.INFO, "Getting subgraph nodes and edges")
    subgraph = self.get_subgraph_nodes_edges(
        num_nodes,
        self.loader.py.asarray(result_vertices),
        {
            "edges": self.loader.py.asarray(result_edges),
            "num_prior_edges": edges_dict["num_prior_edges"],
            "edge_index": edge_index,
        },
        mapping,
    )
    # Report through the module logger instead of printing to stdout
    logger.debug("Extracted subgraph: %s", subgraph)

    return subgraph
|