aiagents4pharma 1.43.0__py3-none-any.whl → 1.45.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (290)
  1. aiagents4pharma/__init__.py +2 -2
  2. aiagents4pharma/talk2aiagents4pharma/.dockerignore +13 -0
  3. aiagents4pharma/talk2aiagents4pharma/Dockerfile +105 -0
  4. aiagents4pharma/talk2aiagents4pharma/README.md +1 -0
  5. aiagents4pharma/talk2aiagents4pharma/__init__.py +4 -5
  6. aiagents4pharma/talk2aiagents4pharma/agents/__init__.py +3 -2
  7. aiagents4pharma/talk2aiagents4pharma/agents/main_agent.py +24 -23
  8. aiagents4pharma/talk2aiagents4pharma/configs/__init__.py +2 -2
  9. aiagents4pharma/talk2aiagents4pharma/configs/agents/__init__.py +2 -2
  10. aiagents4pharma/talk2aiagents4pharma/configs/agents/main_agent/default.yaml +2 -2
  11. aiagents4pharma/talk2aiagents4pharma/configs/config.yaml +1 -1
  12. aiagents4pharma/talk2aiagents4pharma/docker-compose/cpu/.env.example +23 -0
  13. aiagents4pharma/talk2aiagents4pharma/docker-compose/cpu/docker-compose.yml +93 -0
  14. aiagents4pharma/talk2aiagents4pharma/docker-compose/gpu/.env.example +23 -0
  15. aiagents4pharma/talk2aiagents4pharma/docker-compose/gpu/docker-compose.yml +108 -0
  16. aiagents4pharma/talk2aiagents4pharma/install.md +127 -0
  17. aiagents4pharma/talk2aiagents4pharma/states/__init__.py +3 -2
  18. aiagents4pharma/talk2aiagents4pharma/states/state_talk2aiagents4pharma.py +5 -3
  19. aiagents4pharma/talk2aiagents4pharma/tests/__init__.py +2 -2
  20. aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +72 -50
  21. aiagents4pharma/talk2biomodels/.dockerignore +13 -0
  22. aiagents4pharma/talk2biomodels/Dockerfile +104 -0
  23. aiagents4pharma/talk2biomodels/README.md +1 -0
  24. aiagents4pharma/talk2biomodels/__init__.py +4 -8
  25. aiagents4pharma/talk2biomodels/agents/__init__.py +3 -2
  26. aiagents4pharma/talk2biomodels/agents/t2b_agent.py +47 -42
  27. aiagents4pharma/talk2biomodels/api/__init__.py +4 -5
  28. aiagents4pharma/talk2biomodels/api/kegg.py +14 -10
  29. aiagents4pharma/talk2biomodels/api/ols.py +13 -10
  30. aiagents4pharma/talk2biomodels/api/uniprot.py +7 -6
  31. aiagents4pharma/talk2biomodels/configs/__init__.py +3 -4
  32. aiagents4pharma/talk2biomodels/configs/agents/__init__.py +2 -2
  33. aiagents4pharma/talk2biomodels/configs/agents/t2b_agent/__init__.py +2 -2
  34. aiagents4pharma/talk2biomodels/configs/agents/t2b_agent/default.yaml +1 -1
  35. aiagents4pharma/talk2biomodels/configs/config.yaml +1 -1
  36. aiagents4pharma/talk2biomodels/configs/tools/__init__.py +4 -5
  37. aiagents4pharma/talk2biomodels/configs/tools/ask_question/__init__.py +2 -2
  38. aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +1 -2
  39. aiagents4pharma/talk2biomodels/configs/tools/custom_plotter/__init__.py +2 -2
  40. aiagents4pharma/talk2biomodels/configs/tools/custom_plotter/default.yaml +1 -1
  41. aiagents4pharma/talk2biomodels/configs/tools/get_annotation/__init__.py +2 -2
  42. aiagents4pharma/talk2biomodels/configs/tools/get_annotation/default.yaml +1 -1
  43. aiagents4pharma/talk2biomodels/install.md +63 -0
  44. aiagents4pharma/talk2biomodels/models/__init__.py +4 -4
  45. aiagents4pharma/talk2biomodels/models/basico_model.py +36 -28
  46. aiagents4pharma/talk2biomodels/models/sys_bio_model.py +13 -10
  47. aiagents4pharma/talk2biomodels/states/__init__.py +3 -2
  48. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +12 -8
  49. aiagents4pharma/talk2biomodels/tests/BIOMD0000000449_url.xml +1585 -0
  50. aiagents4pharma/talk2biomodels/tests/__init__.py +2 -2
  51. aiagents4pharma/talk2biomodels/tests/article_on_model_537.pdf +0 -0
  52. aiagents4pharma/talk2biomodels/tests/test_api.py +18 -14
  53. aiagents4pharma/talk2biomodels/tests/test_ask_question.py +8 -9
  54. aiagents4pharma/talk2biomodels/tests/test_basico_model.py +15 -9
  55. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +54 -55
  56. aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +28 -27
  57. aiagents4pharma/talk2biomodels/tests/test_integration.py +21 -33
  58. aiagents4pharma/talk2biomodels/tests/test_load_biomodel.py +14 -11
  59. aiagents4pharma/talk2biomodels/tests/test_param_scan.py +21 -20
  60. aiagents4pharma/talk2biomodels/tests/test_query_article.py +129 -29
  61. aiagents4pharma/talk2biomodels/tests/test_search_models.py +9 -13
  62. aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +16 -15
  63. aiagents4pharma/talk2biomodels/tests/test_steady_state.py +12 -22
  64. aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +33 -29
  65. aiagents4pharma/talk2biomodels/tools/__init__.py +15 -12
  66. aiagents4pharma/talk2biomodels/tools/ask_question.py +42 -32
  67. aiagents4pharma/talk2biomodels/tools/custom_plotter.py +51 -43
  68. aiagents4pharma/talk2biomodels/tools/get_annotation.py +99 -75
  69. aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +57 -51
  70. aiagents4pharma/talk2biomodels/tools/load_arguments.py +52 -32
  71. aiagents4pharma/talk2biomodels/tools/load_biomodel.py +8 -2
  72. aiagents4pharma/talk2biomodels/tools/parameter_scan.py +107 -90
  73. aiagents4pharma/talk2biomodels/tools/query_article.py +14 -13
  74. aiagents4pharma/talk2biomodels/tools/search_models.py +37 -26
  75. aiagents4pharma/talk2biomodels/tools/simulate_model.py +47 -37
  76. aiagents4pharma/talk2biomodels/tools/steady_state.py +76 -58
  77. aiagents4pharma/talk2biomodels/tools/utils.py +4 -3
  78. aiagents4pharma/talk2cells/README.md +1 -0
  79. aiagents4pharma/talk2cells/__init__.py +4 -5
  80. aiagents4pharma/talk2cells/agents/__init__.py +3 -2
  81. aiagents4pharma/talk2cells/agents/scp_agent.py +21 -19
  82. aiagents4pharma/talk2cells/states/__init__.py +3 -2
  83. aiagents4pharma/talk2cells/states/state_talk2cells.py +4 -2
  84. aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +8 -9
  85. aiagents4pharma/talk2cells/tools/__init__.py +3 -2
  86. aiagents4pharma/talk2cells/tools/scp_agent/__init__.py +4 -4
  87. aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py +5 -3
  88. aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py +21 -22
  89. aiagents4pharma/talk2knowledgegraphs/.dockerignore +13 -0
  90. aiagents4pharma/talk2knowledgegraphs/Dockerfile +103 -0
  91. aiagents4pharma/talk2knowledgegraphs/README.md +1 -0
  92. aiagents4pharma/talk2knowledgegraphs/__init__.py +4 -7
  93. aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +3 -2
  94. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +40 -30
  95. aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +3 -6
  96. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +2 -2
  97. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +8 -8
  98. aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +3 -2
  99. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +2 -2
  100. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +1 -1
  101. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +1 -1
  102. aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +4 -5
  103. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +2 -2
  104. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +1 -1
  105. aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +17 -2
  106. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +2 -2
  107. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +1 -1
  108. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +2 -2
  109. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +1 -1
  110. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/ols_terms/default.yaml +1 -1
  111. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/reactome_pathways/default.yaml +1 -1
  112. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/uniprot_proteins/default.yaml +1 -1
  113. aiagents4pharma/talk2knowledgegraphs/configs/utils/pubchem_utils/default.yaml +1 -1
  114. aiagents4pharma/talk2knowledgegraphs/datasets/__init__.py +4 -6
  115. aiagents4pharma/talk2knowledgegraphs/datasets/biobridge_primekg.py +115 -67
  116. aiagents4pharma/talk2knowledgegraphs/datasets/dataset.py +2 -0
  117. aiagents4pharma/talk2knowledgegraphs/datasets/primekg.py +35 -24
  118. aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py +29 -21
  119. aiagents4pharma/talk2knowledgegraphs/docker-compose/cpu/.env.example +23 -0
  120. aiagents4pharma/talk2knowledgegraphs/docker-compose/cpu/docker-compose.yml +93 -0
  121. aiagents4pharma/talk2knowledgegraphs/docker-compose/gpu/.env.example +23 -0
  122. aiagents4pharma/talk2knowledgegraphs/docker-compose/gpu/docker-compose.yml +108 -0
  123. aiagents4pharma/talk2knowledgegraphs/entrypoint.sh +190 -0
  124. aiagents4pharma/talk2knowledgegraphs/install.md +140 -0
  125. aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +31 -65
  126. aiagents4pharma/talk2knowledgegraphs/states/__init__.py +3 -2
  127. aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +1 -0
  128. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +65 -40
  129. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +54 -48
  130. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +4 -0
  131. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +17 -4
  132. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +33 -24
  133. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +116 -69
  134. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +736 -413
  135. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_multimodal_subgraph_extraction.py +22 -15
  136. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +19 -12
  137. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +95 -48
  138. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +4 -0
  139. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +5 -0
  140. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_nim_molmim.py +13 -18
  141. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +10 -3
  142. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py +4 -3
  143. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +3 -2
  144. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ols.py +1 -0
  145. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_pubchem.py +9 -4
  146. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_reactome.py +6 -6
  147. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_uniprot.py +4 -0
  148. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +442 -42
  149. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +3 -4
  150. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_pubchem_utils.py +10 -6
  151. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +10 -7
  152. aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +15 -20
  153. aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +245 -205
  154. aiagents4pharma/talk2knowledgegraphs/tools/multimodal_subgraph_extraction.py +92 -90
  155. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +25 -37
  156. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +10 -13
  157. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -7
  158. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +4 -7
  159. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/embeddings.py +4 -0
  160. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/huggingface.py +11 -14
  161. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/nim_molmim.py +7 -7
  162. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +12 -6
  163. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/sentence_transformer.py +8 -6
  164. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +9 -6
  165. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py +1 -0
  166. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py +15 -9
  167. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ols_terms.py +23 -20
  168. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/pubchem_strings.py +12 -10
  169. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/reactome_pathways.py +16 -10
  170. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/uniprot_proteins.py +26 -18
  171. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -5
  172. aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +218 -81
  173. aiagents4pharma/talk2knowledgegraphs/utils/extractions/multimodal_pcst.py +53 -47
  174. aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +18 -14
  175. aiagents4pharma/talk2knowledgegraphs/utils/kg_utils.py +22 -23
  176. aiagents4pharma/talk2knowledgegraphs/utils/pubchem_utils.py +11 -10
  177. aiagents4pharma/talk2scholars/.dockerignore +13 -0
  178. aiagents4pharma/talk2scholars/Dockerfile +104 -0
  179. aiagents4pharma/talk2scholars/README.md +1 -0
  180. aiagents4pharma/talk2scholars/agents/__init__.py +1 -5
  181. aiagents4pharma/talk2scholars/agents/main_agent.py +6 -4
  182. aiagents4pharma/talk2scholars/agents/paper_download_agent.py +5 -4
  183. aiagents4pharma/talk2scholars/agents/pdf_agent.py +4 -2
  184. aiagents4pharma/talk2scholars/agents/s2_agent.py +2 -2
  185. aiagents4pharma/talk2scholars/agents/zotero_agent.py +10 -11
  186. aiagents4pharma/talk2scholars/configs/__init__.py +1 -3
  187. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/__init__.py +1 -4
  188. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +1 -1
  189. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +1 -1
  190. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +8 -8
  191. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +7 -7
  192. aiagents4pharma/talk2scholars/configs/tools/__init__.py +8 -6
  193. aiagents4pharma/talk2scholars/docker-compose/cpu/.env.example +21 -0
  194. aiagents4pharma/talk2scholars/docker-compose/cpu/docker-compose.yml +90 -0
  195. aiagents4pharma/talk2scholars/docker-compose/gpu/.env.example +21 -0
  196. aiagents4pharma/talk2scholars/docker-compose/gpu/docker-compose.yml +105 -0
  197. aiagents4pharma/talk2scholars/install.md +122 -0
  198. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +8 -8
  199. aiagents4pharma/talk2scholars/tests/{test_main_agent.py → test_agents_main_agent.py} +41 -23
  200. aiagents4pharma/talk2scholars/tests/{test_paper_download_agent.py → test_agents_paper_agents_download_agent.py} +10 -16
  201. aiagents4pharma/talk2scholars/tests/{test_pdf_agent.py → test_agents_pdf_agent.py} +6 -10
  202. aiagents4pharma/talk2scholars/tests/{test_s2_agent.py → test_agents_s2_agent.py} +8 -16
  203. aiagents4pharma/talk2scholars/tests/{test_zotero_agent.py → test_agents_zotero_agent.py} +5 -7
  204. aiagents4pharma/talk2scholars/tests/{test_s2_display_dataframe.py → test_s2_tools_display_dataframe.py} +6 -7
  205. aiagents4pharma/talk2scholars/tests/{test_s2_query_dataframe.py → test_s2_tools_query_dataframe.py} +5 -15
  206. aiagents4pharma/talk2scholars/tests/{test_paper_downloader.py → test_tools_paper_downloader.py} +25 -63
  207. aiagents4pharma/talk2scholars/tests/{test_question_and_answer_tool.py → test_tools_question_and_answer_tool.py} +2 -6
  208. aiagents4pharma/talk2scholars/tests/{test_s2_multi.py → test_tools_s2_multi.py} +5 -5
  209. aiagents4pharma/talk2scholars/tests/{test_s2_retrieve.py → test_tools_s2_retrieve.py} +2 -1
  210. aiagents4pharma/talk2scholars/tests/{test_s2_search.py → test_tools_s2_search.py} +5 -5
  211. aiagents4pharma/talk2scholars/tests/{test_s2_single.py → test_tools_s2_single.py} +5 -5
  212. aiagents4pharma/talk2scholars/tests/{test_arxiv_downloader.py → test_utils_arxiv_downloader.py} +16 -25
  213. aiagents4pharma/talk2scholars/tests/{test_base_paper_downloader.py → test_utils_base_paper_downloader.py} +25 -47
  214. aiagents4pharma/talk2scholars/tests/{test_biorxiv_downloader.py → test_utils_biorxiv_downloader.py} +14 -42
  215. aiagents4pharma/talk2scholars/tests/{test_medrxiv_downloader.py → test_utils_medrxiv_downloader.py} +15 -49
  216. aiagents4pharma/talk2scholars/tests/{test_nvidia_nim_reranker.py → test_utils_nvidia_nim_reranker.py} +6 -16
  217. aiagents4pharma/talk2scholars/tests/{test_pdf_answer_formatter.py → test_utils_pdf_answer_formatter.py} +1 -0
  218. aiagents4pharma/talk2scholars/tests/{test_pdf_batch_processor.py → test_utils_pdf_batch_processor.py} +6 -15
  219. aiagents4pharma/talk2scholars/tests/{test_pdf_collection_manager.py → test_utils_pdf_collection_manager.py} +34 -11
  220. aiagents4pharma/talk2scholars/tests/{test_pdf_document_processor.py → test_utils_pdf_document_processor.py} +2 -3
  221. aiagents4pharma/talk2scholars/tests/{test_pdf_generate_answer.py → test_utils_pdf_generate_answer.py} +3 -6
  222. aiagents4pharma/talk2scholars/tests/{test_pdf_gpu_detection.py → test_utils_pdf_gpu_detection.py} +5 -16
  223. aiagents4pharma/talk2scholars/tests/{test_pdf_rag_pipeline.py → test_utils_pdf_rag_pipeline.py} +7 -17
  224. aiagents4pharma/talk2scholars/tests/{test_pdf_retrieve_chunks.py → test_utils_pdf_retrieve_chunks.py} +4 -11
  225. aiagents4pharma/talk2scholars/tests/{test_pdf_singleton_manager.py → test_utils_pdf_singleton_manager.py} +26 -23
  226. aiagents4pharma/talk2scholars/tests/{test_pdf_vector_normalization.py → test_utils_pdf_vector_normalization.py} +1 -1
  227. aiagents4pharma/talk2scholars/tests/{test_pdf_vector_store.py → test_utils_pdf_vector_store.py} +27 -55
  228. aiagents4pharma/talk2scholars/tests/{test_pubmed_downloader.py → test_utils_pubmed_downloader.py} +31 -91
  229. aiagents4pharma/talk2scholars/tests/{test_read_helper_utils.py → test_utils_read_helper_utils.py} +2 -6
  230. aiagents4pharma/talk2scholars/tests/{test_s2_utils_ext_ids.py → test_utils_s2_utils_ext_ids.py} +5 -15
  231. aiagents4pharma/talk2scholars/tests/{test_zotero_human_in_the_loop.py → test_utils_zotero_human_in_the_loop.py} +6 -13
  232. aiagents4pharma/talk2scholars/tests/{test_zotero_path.py → test_utils_zotero_path.py} +53 -45
  233. aiagents4pharma/talk2scholars/tests/{test_zotero_read.py → test_utils_zotero_read.py} +30 -91
  234. aiagents4pharma/talk2scholars/tests/{test_zotero_write.py → test_utils_zotero_write.py} +6 -16
  235. aiagents4pharma/talk2scholars/tools/__init__.py +1 -4
  236. aiagents4pharma/talk2scholars/tools/paper_download/paper_downloader.py +20 -35
  237. aiagents4pharma/talk2scholars/tools/paper_download/utils/__init__.py +7 -5
  238. aiagents4pharma/talk2scholars/tools/paper_download/utils/arxiv_downloader.py +9 -11
  239. aiagents4pharma/talk2scholars/tools/paper_download/utils/base_paper_downloader.py +14 -21
  240. aiagents4pharma/talk2scholars/tools/paper_download/utils/biorxiv_downloader.py +14 -22
  241. aiagents4pharma/talk2scholars/tools/paper_download/utils/medrxiv_downloader.py +11 -13
  242. aiagents4pharma/talk2scholars/tools/paper_download/utils/pubmed_downloader.py +14 -28
  243. aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +4 -8
  244. aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py +16 -14
  245. aiagents4pharma/talk2scholars/tools/pdf/utils/answer_formatter.py +4 -4
  246. aiagents4pharma/talk2scholars/tools/pdf/utils/batch_processor.py +15 -17
  247. aiagents4pharma/talk2scholars/tools/pdf/utils/collection_manager.py +2 -2
  248. aiagents4pharma/talk2scholars/tools/pdf/utils/document_processor.py +5 -5
  249. aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py +4 -4
  250. aiagents4pharma/talk2scholars/tools/pdf/utils/get_vectorstore.py +2 -6
  251. aiagents4pharma/talk2scholars/tools/pdf/utils/gpu_detection.py +5 -9
  252. aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py +4 -4
  253. aiagents4pharma/talk2scholars/tools/pdf/utils/paper_loader.py +2 -2
  254. aiagents4pharma/talk2scholars/tools/pdf/utils/rag_pipeline.py +6 -15
  255. aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py +7 -15
  256. aiagents4pharma/talk2scholars/tools/pdf/utils/singleton_manager.py +2 -2
  257. aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py +3 -4
  258. aiagents4pharma/talk2scholars/tools/pdf/utils/vector_normalization.py +8 -17
  259. aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py +17 -33
  260. aiagents4pharma/talk2scholars/tools/s2/__init__.py +8 -6
  261. aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py +3 -7
  262. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +7 -6
  263. aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +5 -12
  264. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +2 -4
  265. aiagents4pharma/talk2scholars/tools/s2/search.py +6 -6
  266. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +5 -3
  267. aiagents4pharma/talk2scholars/tools/s2/utils/__init__.py +1 -3
  268. aiagents4pharma/talk2scholars/tools/s2/utils/multi_helper.py +12 -18
  269. aiagents4pharma/talk2scholars/tools/s2/utils/search_helper.py +11 -18
  270. aiagents4pharma/talk2scholars/tools/s2/utils/single_helper.py +11 -16
  271. aiagents4pharma/talk2scholars/tools/zotero/__init__.py +1 -4
  272. aiagents4pharma/talk2scholars/tools/zotero/utils/__init__.py +1 -4
  273. aiagents4pharma/talk2scholars/tools/zotero/utils/read_helper.py +21 -39
  274. aiagents4pharma/talk2scholars/tools/zotero/utils/review_helper.py +2 -6
  275. aiagents4pharma/talk2scholars/tools/zotero/utils/write_helper.py +8 -11
  276. aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_path.py +4 -12
  277. aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_pdf_downloader.py +13 -27
  278. aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +4 -7
  279. aiagents4pharma/talk2scholars/tools/zotero/zotero_review.py +8 -10
  280. aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +3 -2
  281. {aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.45.0.dist-info}/METADATA +115 -50
  282. aiagents4pharma-1.45.0.dist-info/RECORD +324 -0
  283. {aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.45.0.dist-info}/WHEEL +1 -2
  284. aiagents4pharma-1.43.0.dist-info/RECORD +0 -293
  285. aiagents4pharma-1.43.0.dist-info/top_level.txt +0 -1
  286. /aiagents4pharma/talk2scholars/tests/{test_state.py → test_states_state.py} +0 -0
  287. /aiagents4pharma/talk2scholars/tests/{test_pdf_paper_loader.py → test_utils_pdf_paper_loader.py} +0 -0
  288. /aiagents4pharma/talk2scholars/tests/{test_tool_helper_utils.py → test_utils_tool_helper_utils.py} +0 -0
  289. /aiagents4pharma/talk2scholars/tests/{test_zotero_pdf_downloader_utils.py → test_utils_zotero_pdf_downloader_utils.py} +0 -0
  290. {aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.45.0.dist-info}/licenses/LICENSE +0 -0
@@ -2,24 +2,155 @@
2
2
  Extraction of multimodal subgraph using Prize-Collecting Steiner Tree (PCST) algorithm.
3
3
  """
4
4
 
5
- from typing import Tuple, NamedTuple
6
5
  import logging
7
6
  import pickle
7
+ import platform
8
+ import subprocess
9
+ from typing import NamedTuple
10
+
11
+ import numpy as np
8
12
  import pandas as pd
9
13
  import pcst_fast
10
14
  from pymilvus import Collection
15
+
11
16
  try:
12
- import cupy as py
13
17
  import cudf
14
- df = cudf
18
+ import cupy as cp
19
+
20
+ CUDF_AVAILABLE = True
15
21
  except ImportError:
16
- import numpy as py
17
- df = pd
22
+ CUDF_AVAILABLE = False
23
+ cudf = None
24
+ cp = None
18
25
 
19
26
  # Initialize logger
20
27
  logging.basicConfig(level=logging.INFO)
21
28
  logger = logging.getLogger(__name__)
22
29
 
30
+
31
+ class SystemDetector:
32
+ """Detect system capabilities and choose appropriate libraries."""
33
+
34
+ def __init__(self):
35
+ self.os_type = platform.system().lower() # 'windows', 'linux', 'darwin'
36
+ self.architecture = platform.machine().lower() # 'x86_64', 'arm64', etc.
37
+ self.has_nvidia_gpu = self._detect_nvidia_gpu()
38
+ self.use_gpu = self.has_nvidia_gpu and self.os_type != "darwin" # No CUDA on macOS
39
+
40
+ logger.info("System Detection Results:")
41
+ logger.info(" OS: %s", self.os_type)
42
+ logger.info(" Architecture: %s", self.architecture)
43
+ logger.info(" NVIDIA GPU detected: %s", self.has_nvidia_gpu)
44
+ logger.info(" Will use GPU acceleration: %s", self.use_gpu)
45
+
46
+ def _detect_nvidia_gpu(self) -> bool:
47
+ """Detect if NVIDIA GPU is available."""
48
+ try:
49
+ # Try nvidia-smi command
50
+ result = subprocess.run(
51
+ ["nvidia-smi"], capture_output=True, text=True, timeout=10, check=False
52
+ )
53
+ return result.returncode == 0
54
+ except (
55
+ subprocess.TimeoutExpired,
56
+ FileNotFoundError,
57
+ subprocess.SubprocessError,
58
+ ):
59
+ return False
60
+
61
+ def get_system_info(self) -> dict:
62
+ """Get comprehensive system information."""
63
+ return {
64
+ "os_type": self.os_type,
65
+ "architecture": self.architecture,
66
+ "has_nvidia_gpu": self.has_nvidia_gpu,
67
+ "use_gpu": self.use_gpu,
68
+ }
69
+
70
+ def is_gpu_compatible(self) -> bool:
71
+ """Check if the system is compatible with GPU acceleration."""
72
+ return self.has_nvidia_gpu and self.os_type != "darwin"
73
+
74
+
75
+ class DynamicLibraryLoader:
76
+ """Dynamically load libraries based on system capabilities."""
77
+
78
+ def __init__(self, detector: SystemDetector):
79
+ self.detector = detector
80
+ self.use_gpu = detector.use_gpu
81
+
82
+ # Initialize attributes that will be set later
83
+ self.py = None
84
+ self.df = None
85
+ self.pd = None
86
+ self.np = None
87
+ self.cudf = None
88
+ self.cp = None
89
+
90
+ # Import libraries based on system capabilities
91
+ self._import_libraries()
92
+
93
+ # Dynamic settings based on hardware
94
+ self.normalize_vectors = self.use_gpu # Only normalize for GPU
95
+ self.metric_type = "IP" if self.use_gpu else "COSINE"
96
+
97
+ logger.info("Library Configuration:")
98
+ logger.info(" Using GPU acceleration: %s", self.use_gpu)
99
+ logger.info(" Vector normalization: %s", self.normalize_vectors)
100
+ logger.info(" Metric type: %s", self.metric_type)
101
+
102
+ def _import_libraries(self):
103
+ """Dynamically import libraries based on system capabilities."""
104
+ # Set base libraries
105
+ self.pd = pd
106
+ self.np = np
107
+
108
+ # Conditionally import GPU libraries
109
+ if self.detector.use_gpu:
110
+ if CUDF_AVAILABLE:
111
+ self.cudf = cudf
112
+ self.cp = cp
113
+ self.py = cp # Use cupy for array operations
114
+ self.df = cudf # Use cudf for dataframes
115
+ logger.info("Successfully imported GPU libraries (cudf, cupy)")
116
+ else:
117
+ logger.error("cudf or cupy not found. Falling back to CPU mode.")
118
+ self.detector.use_gpu = False
119
+ self.use_gpu = False
120
+ self._setup_cpu_mode()
121
+ else:
122
+ self._setup_cpu_mode()
123
+
124
+ def _setup_cpu_mode(self):
125
+ """Setup CPU mode with numpy and pandas."""
126
+ self.py = self.np # Use numpy for array operations
127
+ self.df = self.pd # Use pandas for dataframes
128
+ self.normalize_vectors = False
129
+ self.metric_type = "COSINE"
130
+ logger.info("Using CPU mode with numpy and pandas")
131
+
132
+ def normalize_matrix(self, matrix, axis: int = 1):
133
+ """Normalize matrix using appropriate library."""
134
+ if not self.normalize_vectors:
135
+ return matrix
136
+
137
+ if self.use_gpu:
138
+ # Use cupy for GPU
139
+ matrix_cp = self.cp.asarray(matrix).astype(self.cp.float32)
140
+ norms = self.cp.linalg.norm(matrix_cp, axis=axis, keepdims=True)
141
+ return matrix_cp / norms
142
+ # CPU mode doesn't normalize for COSINE similarity
143
+ return matrix
144
+
145
+ def to_list(self, data):
146
+ """Convert data to list format."""
147
+ if hasattr(data, "tolist"):
148
+ return data.tolist()
149
+ if hasattr(data, "to_arrow"):
150
+ return data.to_arrow().to_pylist()
151
+ return list(data)
152
+
153
+
23
154
  class MultimodalPCSTPruning(NamedTuple):
24
155
  """
25
156
  Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
@@ -37,7 +168,11 @@ class MultimodalPCSTPruning(NamedTuple):
37
168
  num_clusters: The number of clusters.
38
169
  pruning: The pruning strategy to use.
39
170
  verbosity_level: The verbosity level.
171
+ use_description: Whether to use description embeddings.
172
+ metric_type: The similarity metric type (dynamic based on hardware).
173
+ loader: The dynamic library loader instance.
40
174
  """
175
+
41
176
  topk: int = 3
42
177
  topk_e: int = 3
43
178
  cost_e: float = 0.5
@@ -47,7 +182,8 @@ class MultimodalPCSTPruning(NamedTuple):
47
182
  pruning: str = "gw"
48
183
  verbosity_level: int = 0
49
184
  use_description: bool = False
50
- metric_type: str = "IP" # Inner Product
185
+ metric_type: str = None # Will be set dynamically
186
+ loader: DynamicLibraryLoader = None
51
187
 
52
188
  def prepare_collections(self, cfg: dict, modality: str) -> dict:
53
189
  """
@@ -81,11 +217,9 @@ class MultimodalPCSTPruning(NamedTuple):
81
217
 
82
218
  return colls
83
219
 
84
- def _compute_node_prizes(self,
85
- query_emb: list,
86
- colls: dict) -> dict:
220
+ def _compute_node_prizes(self, query_emb: list, colls: dict) -> dict:
87
221
  """
88
- Compute the node prizes based on the cosine similarity between the query and nodes.
222
+ Compute the node prizes based on the similarity between the query and nodes.
89
223
 
90
224
  Args:
91
225
  query_emb: The query embedding. This can be an embedding of
@@ -95,66 +229,73 @@ class MultimodalPCSTPruning(NamedTuple):
95
229
  Returns:
96
230
  The prizes of the nodes.
97
231
  """
98
- # Intialize several variables
232
+ # Initialize several variables
99
233
  topk = min(self.topk, colls["nodes"].num_entities)
100
- n_prizes = py.zeros(colls["nodes"].num_entities, dtype=py.float32)
234
+ n_prizes = self.loader.py.zeros(colls["nodes"].num_entities, dtype=self.loader.py.float32)
101
235
 
102
- # Calculate cosine similarity for text features and update the score
236
+ # Get the actual metric type to use
237
+ actual_metric_type = self.metric_type or self.loader.metric_type
238
+
239
+ # Calculate similarity for text features and update the score
103
240
  if self.use_description:
104
241
  # Search the collection with the text embedding
105
242
  res = colls["nodes"].search(
106
243
  data=[query_emb],
107
244
  anns_field="desc_emb",
108
- param={"metric_type": self.metric_type},
245
+ param={"metric_type": actual_metric_type},
109
246
  limit=topk,
110
- output_fields=["node_id"])
247
+ output_fields=["node_id"],
248
+ )
111
249
  else:
112
250
  # Search the collection with the query embedding
113
251
  res = colls["nodes_type"].search(
114
252
  data=[query_emb],
115
253
  anns_field="feat_emb",
116
- param={"metric_type": self.metric_type},
254
+ param={"metric_type": actual_metric_type},
117
255
  limit=topk,
118
- output_fields=["node_id"])
256
+ output_fields=["node_id"],
257
+ )
119
258
 
120
259
  # Update the prizes based on the search results
121
- n_prizes[[r.id for r in res[0]]] = py.arange(topk, 0, -1).astype(py.float32)
260
+ n_prizes[[r.id for r in res[0]]] = self.loader.py.arange(topk, 0, -1).astype(
261
+ self.loader.py.float32
262
+ )
122
263
 
123
264
  return n_prizes
124
265
 
125
- def _compute_edge_prizes(self,
126
- text_emb: list,
127
- colls: dict) -> py.ndarray:
266
+ def _compute_edge_prizes(self, text_emb: list, colls: dict):
128
267
  """
129
- Compute the node prizes based on the cosine similarity between the query and nodes.
268
+ Compute the edge prizes based on the similarity between the query and edges.
130
269
 
131
270
  Args:
132
271
  text_emb: The textual description embedding.
133
272
  colls: The collections of nodes, node-type specific nodes, and edges in Milvus.
134
273
 
135
274
  Returns:
136
- The prizes of the nodes.
275
+ The prizes of the edges.
137
276
  """
138
- # Intialize several variables
277
+ # Initialize several variables
139
278
  topk_e = min(self.topk_e, colls["edges"].num_entities)
140
- e_prizes = py.zeros(colls["edges"].num_entities, dtype=py.float32)
279
+ e_prizes = self.loader.py.zeros(colls["edges"].num_entities, dtype=self.loader.py.float32)
280
+
281
+ # Get the actual metric type to use
282
+ actual_metric_type = self.metric_type or self.loader.metric_type
141
283
 
142
284
  # Search the collection with the query embedding
143
285
  res = colls["edges"].search(
144
286
  data=[text_emb],
145
287
  anns_field="feat_emb",
146
- param={"metric_type": self.metric_type},
147
- limit=topk_e, # Only retrieve the top-k edges
148
- # limit=colls["edges"].num_entities,
149
- output_fields=["head_id", "tail_id"])
288
+ param={"metric_type": actual_metric_type},
289
+ limit=topk_e, # Only retrieve the top-k edges
290
+ output_fields=["head_id", "tail_id"],
291
+ )
150
292
 
151
293
  # Update the prizes based on the search results
152
294
  e_prizes[[r.id for r in res[0]]] = [r.score for r in res[0]]
153
295
 
154
296
  # Further process the edge_prizes
155
- unique_prizes, inverse_indices = py.unique(e_prizes, return_inverse=True)
156
- topk_e_values = unique_prizes[py.argsort(-unique_prizes)[:topk_e]]
157
- # e_prizes[e_prizes < topk_e_values[-1]] = 0.0
297
+ unique_prizes, inverse_indices = self.loader.py.unique(e_prizes, return_inverse=True)
298
+ topk_e_values = unique_prizes[self.loader.py.argsort(-unique_prizes)[:topk_e]]
158
299
  last_topk_e_value = topk_e
159
300
  for k in range(topk_e):
160
301
  indices = inverse_indices == (unique_prizes == topk_e_values[k]).nonzero()[0]
@@ -164,10 +305,7 @@ class MultimodalPCSTPruning(NamedTuple):
164
305
 
165
306
  return e_prizes
166
307
 
167
- def compute_prizes(self,
168
- text_emb: list,
169
- query_emb: list,
170
- colls: dict) -> dict:
308
+ def compute_prizes(self, text_emb: list, query_emb: list, colls: dict) -> dict:
171
309
  """
172
310
  Compute the node prizes based on the cosine similarity between the query and nodes,
173
311
  as well as the edge prizes based on the cosine similarity between the query and edges.
@@ -193,10 +331,7 @@ class MultimodalPCSTPruning(NamedTuple):
193
331
 
194
332
  return {"nodes": n_prizes, "edges": e_prizes}
195
333
 
196
- def compute_subgraph_costs(self,
197
- edge_index: py.ndarray,
198
- num_nodes: int,
199
- prizes: dict) -> Tuple[py.ndarray, py.ndarray, py.ndarray]:
334
+ def compute_subgraph_costs(self, edge_index, num_nodes: int, prizes: dict):
200
335
  """
201
336
  Compute the costs in constructing the subgraph proposed by G-Retriever paper.
202
337
 
@@ -218,7 +353,7 @@ class MultimodalPCSTPruning(NamedTuple):
218
353
  # Update edge cost threshold
219
354
  updated_cost_e = min(
220
355
  self.cost_e,
221
- py.max(prizes["edges"]).item() * (1 - self.c_const / 2),
356
+ self.loader.py.max(prizes["edges"]).item() * (1 - self.c_const / 2),
222
357
  )
223
358
 
224
359
  # Masks for real and virtual edges
@@ -228,19 +363,21 @@ class MultimodalPCSTPruning(NamedTuple):
228
363
 
229
364
  # Real edge indices
230
365
  logger.log(logging.INFO, "Computing real edges")
231
- real_["indices"] = py.nonzero(real_["mask"])[0]
366
+ real_["indices"] = self.loader.py.nonzero(real_["mask"])[0]
232
367
  real_["src"] = edge_index[0][real_["indices"]]
233
368
  real_["dst"] = edge_index[1][real_["indices"]]
234
- real_["edges"] = py.stack([real_["src"], real_["dst"]], axis=1)
369
+ real_["edges"] = self.loader.py.stack([real_["src"], real_["dst"]], axis=1)
235
370
  real_["costs"] = updated_cost_e - prizes["edges"][real_["indices"]]
236
371
 
237
372
  # Edge index mapping: local real edge idx -> original global index
238
373
  logger.log(logging.INFO, "Creating mapping for real edges")
239
- mapping_edges = dict(zip(range(len(real_["indices"])), real_["indices"].tolist()))
374
+ mapping_edges = dict(
375
+ zip(range(len(real_["indices"])), self.loader.to_list(real_["indices"]), strict=False)
376
+ )
240
377
 
241
378
  # Virtual edge handling
242
379
  logger.log(logging.INFO, "Computing virtual edges")
243
- virt_["indices"] = py.nonzero(virt_["mask"])[0]
380
+ virt_["indices"] = self.loader.py.nonzero(virt_["mask"])[0]
244
381
  virt_["src"] = edge_index[0][virt_["indices"]]
245
382
  virt_["dst"] = edge_index[1][virt_["indices"]]
246
383
  virt_["prizes"] = prizes["edges"][virt_["indices"]] - updated_cost_e
@@ -248,28 +385,35 @@ class MultimodalPCSTPruning(NamedTuple):
248
385
  # Generate virtual node IDs
249
386
  logger.log(logging.INFO, "Generating virtual node IDs")
250
387
  virt_["num"] = virt_["indices"].shape[0]
251
- virt_["node_ids"] = py.arange(num_nodes, num_nodes + virt_["num"])
388
+ virt_["node_ids"] = self.loader.py.arange(num_nodes, num_nodes + virt_["num"])
252
389
 
253
390
  # Virtual edges: (src → virtual), (virtual → dst)
254
391
  logger.log(logging.INFO, "Creating virtual edges")
255
- virt_["edges_1"] = py.stack([virt_["src"], virt_["node_ids"]], axis=1)
256
- virt_["edges_2"] = py.stack([virt_["node_ids"], virt_["dst"]], axis=1)
257
- virt_["edges"] = py.concatenate([virt_["edges_1"],
258
- virt_["edges_2"]], axis=0)
259
- virt_["costs"] = py.zeros((virt_["edges"].shape[0],), dtype=real_["costs"].dtype)
392
+ virt_["edges_1"] = self.loader.py.stack([virt_["src"], virt_["node_ids"]], axis=1)
393
+ virt_["edges_2"] = self.loader.py.stack([virt_["node_ids"], virt_["dst"]], axis=1)
394
+ virt_["edges"] = self.loader.py.concatenate([virt_["edges_1"], virt_["edges_2"]], axis=0)
395
+ virt_["costs"] = self.loader.py.zeros(
396
+ (virt_["edges"].shape[0],), dtype=real_["costs"].dtype
397
+ )
260
398
 
261
399
  # Combine real and virtual edges/costs
262
400
  logger.log(logging.INFO, "Combining real and virtual edges/costs")
263
- all_edges = py.concatenate([real_["edges"], virt_["edges"]], axis=0)
264
- all_costs = py.concatenate([real_["costs"], virt_["costs"]], axis=0)
401
+ all_edges = self.loader.py.concatenate([real_["edges"], virt_["edges"]], axis=0)
402
+ all_costs = self.loader.py.concatenate([real_["costs"], virt_["costs"]], axis=0)
265
403
 
266
404
  # Final prizes
267
405
  logger.log(logging.INFO, "Getting final prizes")
268
- final_prizes = py.concatenate([prizes["nodes"], virt_["prizes"]], axis=0)
406
+ final_prizes = self.loader.py.concatenate([prizes["nodes"], virt_["prizes"]], axis=0)
269
407
 
270
408
  # Mapping virtual node ID -> edge index in original graph
271
409
  logger.log(logging.INFO, "Creating mapping for virtual nodes")
272
- mapping_nodes = dict(zip(virt_["node_ids"].tolist(), virt_["indices"].tolist()))
410
+ mapping_nodes = dict(
411
+ zip(
412
+ self.loader.to_list(virt_["node_ids"]),
413
+ self.loader.to_list(virt_["indices"]),
414
+ strict=False,
415
+ )
416
+ )
273
417
 
274
418
  # Build return values
275
419
  logger.log(logging.INFO, "Building return values")
@@ -284,11 +428,9 @@ class MultimodalPCSTPruning(NamedTuple):
284
428
 
285
429
  return edges_dict, final_prizes, all_costs, mapping
286
430
 
287
- def get_subgraph_nodes_edges(self,
288
- num_nodes: int,
289
- vertices: py.ndarray,
290
- edges_dict: dict,
291
- mapping: dict) -> dict:
431
+ def get_subgraph_nodes_edges(
432
+ self, num_nodes: int, vertices, edges_dict: dict, mapping: dict
433
+ ) -> dict:
292
434
  """
293
435
  Get the selected nodes and edges of the subgraph based on the vertices and edges computed
294
436
  by the PCST algorithm.
@@ -305,31 +447,22 @@ class MultimodalPCSTPruning(NamedTuple):
305
447
  # Get edges information
306
448
  edges = edges_dict["edges"]
307
449
  num_prior_edges = edges_dict["num_prior_edges"]
308
- # Get edges information
309
- edges = edges_dict["edges"]
310
- num_prior_edges = edges_dict["num_prior_edges"]
450
+
311
451
  # Retrieve the selected nodes and edges based on the given vertices and edges
312
452
  subgraph_nodes = vertices[vertices < num_nodes]
313
453
  subgraph_edges = [mapping["edges"][e.item()] for e in edges if e < num_prior_edges]
314
454
  virtual_vertices = vertices[vertices >= num_nodes]
315
455
  if len(virtual_vertices) > 0:
316
- virtual_vertices = vertices[vertices >= num_nodes]
317
456
  virtual_edges = [mapping["nodes"][i.item()] for i in virtual_vertices]
318
- subgraph_edges = py.array(subgraph_edges + virtual_edges)
457
+ subgraph_edges = self.loader.py.array(subgraph_edges + virtual_edges)
319
458
  edge_index = edges_dict["edge_index"][:, subgraph_edges]
320
- subgraph_nodes = py.unique(
321
- py.concatenate(
322
- [subgraph_nodes, edge_index[0], edge_index[1]]
323
- )
459
+ subgraph_nodes = self.loader.py.unique(
460
+ self.loader.py.concatenate([subgraph_nodes, edge_index[0], edge_index[1]])
324
461
  )
325
462
 
326
463
  return {"nodes": subgraph_nodes, "edges": subgraph_edges}
327
464
 
328
- def extract_subgraph(self,
329
- text_emb: list,
330
- query_emb: list,
331
- modality: str,
332
- cfg: dict) -> dict:
465
+ def extract_subgraph(self, text_emb: list, query_emb: list, modality: str, cfg: dict) -> dict:
333
466
  """
334
467
  Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.
335
468
 
@@ -352,7 +485,7 @@ class MultimodalPCSTPruning(NamedTuple):
352
485
  logger.log(logging.INFO, "Loading cache edge index")
353
486
  with open(cfg.milvus_db.cache_edge_index_path, "rb") as f:
354
487
  edge_index = pickle.load(f)
355
- edge_index = py.array(edge_index)
488
+ edge_index = self.loader.py.array(edge_index)
356
489
 
357
490
  # Assert the topk and topk_e values for subgraph retrieval
358
491
  assert self.topk > 0, "topk must be greater than or equal to 0"
@@ -365,7 +498,8 @@ class MultimodalPCSTPruning(NamedTuple):
365
498
  # Compute costs in constructing the subgraph
366
499
  logger.log(logging.INFO, "compute_subgraph_costs")
367
500
  edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
368
- edge_index, colls["nodes"].num_entities, prizes)
501
+ edge_index, colls["nodes"].num_entities, prizes
502
+ )
369
503
 
370
504
  # Retrieve the subgraph using the PCST algorithm
371
505
  logger.log(logging.INFO, "Running PCST algorithm")
@@ -383,11 +517,14 @@ class MultimodalPCSTPruning(NamedTuple):
383
517
  logger.log(logging.INFO, "Getting subgraph nodes and edges")
384
518
  subgraph = self.get_subgraph_nodes_edges(
385
519
  colls["nodes"].num_entities,
386
- py.asarray(result_vertices),
387
- {"edges": py.asarray(result_edges),
388
- "num_prior_edges": edges_dict["num_prior_edges"],
389
- "edge_index": edge_index},
390
- mapping)
520
+ self.loader.py.asarray(result_vertices),
521
+ {
522
+ "edges": self.loader.py.asarray(result_edges),
523
+ "num_prior_edges": edges_dict["num_prior_edges"],
524
+ "edge_index": edge_index,
525
+ },
526
+ mapping,
527
+ )
391
528
  print(subgraph)
392
529
 
393
530
  return subgraph
@@ -2,13 +2,15 @@
2
2
  Exctraction of multimodal subgraph using Prize-Collecting Steiner Tree (PCST) algorithm.
3
3
  """
4
4
 
5
- from typing import Tuple, NamedTuple
5
+ from typing import NamedTuple
6
+
6
7
  import numpy as np
7
8
  import pandas as pd
8
- import torch
9
9
  import pcst_fast
10
+ import torch
10
11
  from torch_geometric.data.data import Data
11
12
 
13
+
12
14
  class MultimodalPCSTPruning(NamedTuple):
13
15
  """
14
16
  Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
@@ -27,6 +29,7 @@ class MultimodalPCSTPruning(NamedTuple):
27
29
  pruning: The pruning strategy to use.
28
30
  verbosity_level: The verbosity level.
29
31
  """
32
+
30
33
  topk: int = 3
31
34
  topk_e: int = 3
32
35
  cost_e: float = 0.5
@@ -37,10 +40,7 @@ class MultimodalPCSTPruning(NamedTuple):
37
40
  verbosity_level: int = 0
38
41
  use_description: bool = False
39
42
 
40
- def _compute_node_prizes(self,
41
- graph: Data,
42
- query_emb: torch.Tensor,
43
- modality: str) :
43
+ def _compute_node_prizes(self, graph: Data, query_emb: torch.Tensor, modality: str):
44
44
  """
45
45
  Compute the node prizes based on the cosine similarity between the query and nodes.
46
46
 
@@ -54,25 +54,28 @@ class MultimodalPCSTPruning(NamedTuple):
54
54
  The prizes of the nodes.
55
55
  """
56
56
  # Convert PyG graph to a DataFrame
57
- graph_df = pd.DataFrame({
58
- "node_type": graph.node_type,
59
- "desc_x": [x.tolist() for x in graph.desc_x],
60
- "x": [list(x) for x in graph.x],
61
- "score": [0.0 for _ in range(len(graph.node_id))],
62
- })
57
+ graph_df = pd.DataFrame(
58
+ {
59
+ "node_type": graph.node_type,
60
+ "desc_x": [x.tolist() for x in graph.desc_x],
61
+ "x": [list(x) for x in graph.x],
62
+ "score": [0.0 for _ in range(len(graph.node_id))],
63
+ }
64
+ )
63
65
 
64
66
  # Calculate cosine similarity for text features and update the score
65
67
  if self.use_description:
66
68
  graph_df.loc[:, "score"] = torch.nn.CosineSimilarity(dim=-1)(
67
- query_emb,
68
- torch.tensor(list(graph_df.desc_x.values)) # Using textual description features
69
- ).tolist()
69
+ query_emb,
70
+ torch.tensor(list(graph_df.desc_x.values)), # Using textual description features
71
+ ).tolist()
70
72
  else:
71
- graph_df.loc[graph_df["node_type"] == modality,
72
- "score"] = torch.nn.CosineSimilarity(dim=-1)(
73
- query_emb,
74
- torch.tensor(list(graph_df[graph_df["node_type"]== modality].x.values))
75
- ).tolist()
73
+ graph_df.loc[graph_df["node_type"] == modality, "score"] = torch.nn.CosineSimilarity(
74
+ dim=-1
75
+ )(
76
+ query_emb,
77
+ torch.tensor(list(graph_df[graph_df["node_type"] == modality].x.values)),
78
+ ).tolist()
76
79
 
77
80
  # Set the prizes for nodes based on the similarity scores
78
81
  n_prizes = torch.tensor(graph_df.score.values, dtype=torch.float32)
@@ -84,9 +87,7 @@ class MultimodalPCSTPruning(NamedTuple):
84
87
 
85
88
  return n_prizes
86
89
 
87
- def _compute_edge_prizes(self,
88
- graph: Data,
89
- text_emb: torch.Tensor) :
90
+ def _compute_edge_prizes(self, graph: Data, text_emb: torch.Tensor):
90
91
  """
91
92
  Compute the node prizes based on the cosine similarity between the query and nodes.
92
93
 
@@ -106,20 +107,22 @@ class MultimodalPCSTPruning(NamedTuple):
106
107
  e_prizes[e_prizes < topk_e_values[-1]] = 0.0
107
108
  last_topk_e_value = topk_e
108
109
  for k in range(topk_e):
109
- indices = inverse_indices == (
110
- unique_prizes == topk_e_values[k]
111
- ).nonzero(as_tuple=True)[0]
110
+ indices = (
111
+ inverse_indices == (unique_prizes == topk_e_values[k]).nonzero(as_tuple=True)[0]
112
+ )
112
113
  value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
113
114
  e_prizes[indices] = value
114
115
  last_topk_e_value = value * (1 - self.c_const)
115
116
 
116
117
  return e_prizes
117
118
 
118
- def compute_prizes(self,
119
- graph: Data,
120
- text_emb: torch.Tensor,
121
- query_emb: torch.Tensor,
122
- modality: str):
119
+ def compute_prizes(
120
+ self,
121
+ graph: Data,
122
+ text_emb: torch.Tensor,
123
+ query_emb: torch.Tensor,
124
+ modality: str,
125
+ ):
123
126
  """
124
127
  Compute the node prizes based on the cosine similarity between the query and nodes,
125
128
  as well as the edge prizes based on the cosine similarity between the query and edges.
@@ -144,9 +147,9 @@ class MultimodalPCSTPruning(NamedTuple):
144
147
 
145
148
  return {"nodes": n_prizes, "edges": e_prizes}
146
149
 
147
- def compute_subgraph_costs(self,
148
- graph: Data,
149
- prizes: dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
150
+ def compute_subgraph_costs(
151
+ self, graph: Data, prizes: dict
152
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
150
153
  """
151
154
  Compute the costs in constructing the subgraph proposed by G-Retriever paper.
152
155
 
@@ -204,7 +207,11 @@ class MultimodalPCSTPruning(NamedTuple):
204
207
  return edges_dict, prizes, costs, mapping
205
208
 
206
209
  def get_subgraph_nodes_edges(
207
- self, graph: Data, vertices: np.ndarray, edges_dict: dict, mapping: dict,
210
+ self,
211
+ graph: Data,
212
+ vertices: np.ndarray,
213
+ edges_dict: dict,
214
+ mapping: dict,
208
215
  ) -> dict:
209
216
  """
210
217
  Get the selected nodes and edges of the subgraph based on the vertices and edges computed
@@ -234,18 +241,18 @@ class MultimodalPCSTPruning(NamedTuple):
234
241
  subgraph_edges = np.array(subgraph_edges + virtual_edges)
235
242
  edge_index = graph.edge_index[:, subgraph_edges]
236
243
  subgraph_nodes = np.unique(
237
- np.concatenate(
238
- [subgraph_nodes, edge_index[0].numpy(), edge_index[1].numpy()]
239
- )
244
+ np.concatenate([subgraph_nodes, edge_index[0].numpy(), edge_index[1].numpy()])
240
245
  )
241
246
 
242
247
  return {"nodes": subgraph_nodes, "edges": subgraph_edges}
243
248
 
244
- def extract_subgraph(self,
245
- graph: Data,
246
- text_emb: torch.Tensor,
247
- query_emb: torch.Tensor,
248
- modality: str) -> dict:
249
+ def extract_subgraph(
250
+ self,
251
+ graph: Data,
252
+ text_emb: torch.Tensor,
253
+ query_emb: torch.Tensor,
254
+ modality: str,
255
+ ) -> dict:
249
256
  """
250
257
  Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.
251
258
 
@@ -268,9 +275,7 @@ class MultimodalPCSTPruning(NamedTuple):
268
275
  prizes = self.compute_prizes(graph, text_emb, query_emb, modality)
269
276
 
270
277
  # Compute costs in constructing the subgraph
271
- edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
272
- graph, prizes
273
- )
278
+ edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(graph, prizes)
274
279
 
275
280
  # Retrieve the subgraph using the PCST algorithm
276
281
  result_vertices, result_edges = pcst_fast.pcst_fast(
@@ -287,6 +292,7 @@ class MultimodalPCSTPruning(NamedTuple):
287
292
  graph,
288
293
  result_vertices,
289
294
  {"edges": result_edges, "num_prior_edges": edges_dict["num_prior_edges"]},
290
- mapping)
295
+ mapping,
296
+ )
291
297
 
292
298
  return subgraph