aiagents4pharma-0.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (336)
  1. aiagents4pharma/__init__.py +11 -0
  2. aiagents4pharma/talk2aiagents4pharma/.dockerignore +13 -0
  3. aiagents4pharma/talk2aiagents4pharma/Dockerfile +133 -0
  4. aiagents4pharma/talk2aiagents4pharma/README.md +1 -0
  5. aiagents4pharma/talk2aiagents4pharma/__init__.py +5 -0
  6. aiagents4pharma/talk2aiagents4pharma/agents/__init__.py +6 -0
  7. aiagents4pharma/talk2aiagents4pharma/agents/main_agent.py +70 -0
  8. aiagents4pharma/talk2aiagents4pharma/configs/__init__.py +5 -0
  9. aiagents4pharma/talk2aiagents4pharma/configs/agents/__init__.py +5 -0
  10. aiagents4pharma/talk2aiagents4pharma/configs/agents/main_agent/default.yaml +29 -0
  11. aiagents4pharma/talk2aiagents4pharma/configs/app/__init__.py +0 -0
  12. aiagents4pharma/talk2aiagents4pharma/configs/app/frontend/__init__.py +0 -0
  13. aiagents4pharma/talk2aiagents4pharma/configs/app/frontend/default.yaml +102 -0
  14. aiagents4pharma/talk2aiagents4pharma/configs/config.yaml +4 -0
  15. aiagents4pharma/talk2aiagents4pharma/docker-compose/cpu/.env.example +23 -0
  16. aiagents4pharma/talk2aiagents4pharma/docker-compose/cpu/docker-compose.yml +93 -0
  17. aiagents4pharma/talk2aiagents4pharma/docker-compose/gpu/.env.example +23 -0
  18. aiagents4pharma/talk2aiagents4pharma/docker-compose/gpu/docker-compose.yml +108 -0
  19. aiagents4pharma/talk2aiagents4pharma/install.md +154 -0
  20. aiagents4pharma/talk2aiagents4pharma/states/__init__.py +5 -0
  21. aiagents4pharma/talk2aiagents4pharma/states/state_talk2aiagents4pharma.py +18 -0
  22. aiagents4pharma/talk2aiagents4pharma/tests/__init__.py +3 -0
  23. aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +312 -0
  24. aiagents4pharma/talk2biomodels/.dockerignore +13 -0
  25. aiagents4pharma/talk2biomodels/Dockerfile +104 -0
  26. aiagents4pharma/talk2biomodels/README.md +1 -0
  27. aiagents4pharma/talk2biomodels/__init__.py +5 -0
  28. aiagents4pharma/talk2biomodels/agents/__init__.py +6 -0
  29. aiagents4pharma/talk2biomodels/agents/t2b_agent.py +104 -0
  30. aiagents4pharma/talk2biomodels/api/__init__.py +5 -0
  31. aiagents4pharma/talk2biomodels/api/ols.py +75 -0
  32. aiagents4pharma/talk2biomodels/api/uniprot.py +36 -0
  33. aiagents4pharma/talk2biomodels/configs/__init__.py +5 -0
  34. aiagents4pharma/talk2biomodels/configs/agents/__init__.py +5 -0
  35. aiagents4pharma/talk2biomodels/configs/agents/t2b_agent/__init__.py +3 -0
  36. aiagents4pharma/talk2biomodels/configs/agents/t2b_agent/default.yaml +14 -0
  37. aiagents4pharma/talk2biomodels/configs/app/__init__.py +0 -0
  38. aiagents4pharma/talk2biomodels/configs/app/frontend/__init__.py +0 -0
  39. aiagents4pharma/talk2biomodels/configs/app/frontend/default.yaml +72 -0
  40. aiagents4pharma/talk2biomodels/configs/config.yaml +7 -0
  41. aiagents4pharma/talk2biomodels/configs/tools/__init__.py +5 -0
  42. aiagents4pharma/talk2biomodels/configs/tools/ask_question/__init__.py +3 -0
  43. aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +30 -0
  44. aiagents4pharma/talk2biomodels/configs/tools/custom_plotter/__init__.py +3 -0
  45. aiagents4pharma/talk2biomodels/configs/tools/custom_plotter/default.yaml +8 -0
  46. aiagents4pharma/talk2biomodels/configs/tools/get_annotation/__init__.py +3 -0
  47. aiagents4pharma/talk2biomodels/configs/tools/get_annotation/default.yaml +8 -0
  48. aiagents4pharma/talk2biomodels/install.md +63 -0
  49. aiagents4pharma/talk2biomodels/models/__init__.py +5 -0
  50. aiagents4pharma/talk2biomodels/models/basico_model.py +125 -0
  51. aiagents4pharma/talk2biomodels/models/sys_bio_model.py +60 -0
  52. aiagents4pharma/talk2biomodels/states/__init__.py +6 -0
  53. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +49 -0
  54. aiagents4pharma/talk2biomodels/tests/BIOMD0000000449_url.xml +1585 -0
  55. aiagents4pharma/talk2biomodels/tests/__init__.py +3 -0
  56. aiagents4pharma/talk2biomodels/tests/article_on_model_537.pdf +0 -0
  57. aiagents4pharma/talk2biomodels/tests/test_api.py +31 -0
  58. aiagents4pharma/talk2biomodels/tests/test_ask_question.py +42 -0
  59. aiagents4pharma/talk2biomodels/tests/test_basico_model.py +67 -0
  60. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +190 -0
  61. aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +92 -0
  62. aiagents4pharma/talk2biomodels/tests/test_integration.py +116 -0
  63. aiagents4pharma/talk2biomodels/tests/test_load_biomodel.py +35 -0
  64. aiagents4pharma/talk2biomodels/tests/test_param_scan.py +71 -0
  65. aiagents4pharma/talk2biomodels/tests/test_query_article.py +184 -0
  66. aiagents4pharma/talk2biomodels/tests/test_save_model.py +47 -0
  67. aiagents4pharma/talk2biomodels/tests/test_search_models.py +35 -0
  68. aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +44 -0
  69. aiagents4pharma/talk2biomodels/tests/test_steady_state.py +86 -0
  70. aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +67 -0
  71. aiagents4pharma/talk2biomodels/tools/__init__.py +17 -0
  72. aiagents4pharma/talk2biomodels/tools/ask_question.py +125 -0
  73. aiagents4pharma/talk2biomodels/tools/custom_plotter.py +165 -0
  74. aiagents4pharma/talk2biomodels/tools/get_annotation.py +342 -0
  75. aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +159 -0
  76. aiagents4pharma/talk2biomodels/tools/load_arguments.py +134 -0
  77. aiagents4pharma/talk2biomodels/tools/load_biomodel.py +44 -0
  78. aiagents4pharma/talk2biomodels/tools/parameter_scan.py +310 -0
  79. aiagents4pharma/talk2biomodels/tools/query_article.py +64 -0
  80. aiagents4pharma/talk2biomodels/tools/save_model.py +98 -0
  81. aiagents4pharma/talk2biomodels/tools/search_models.py +96 -0
  82. aiagents4pharma/talk2biomodels/tools/simulate_model.py +137 -0
  83. aiagents4pharma/talk2biomodels/tools/steady_state.py +187 -0
  84. aiagents4pharma/talk2biomodels/tools/utils.py +23 -0
  85. aiagents4pharma/talk2cells/README.md +1 -0
  86. aiagents4pharma/talk2cells/__init__.py +5 -0
  87. aiagents4pharma/talk2cells/agents/__init__.py +6 -0
  88. aiagents4pharma/talk2cells/agents/scp_agent.py +87 -0
  89. aiagents4pharma/talk2cells/states/__init__.py +6 -0
  90. aiagents4pharma/talk2cells/states/state_talk2cells.py +15 -0
  91. aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +22 -0
  92. aiagents4pharma/talk2cells/tools/__init__.py +6 -0
  93. aiagents4pharma/talk2cells/tools/scp_agent/__init__.py +6 -0
  94. aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py +27 -0
  95. aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py +78 -0
  96. aiagents4pharma/talk2knowledgegraphs/.dockerignore +13 -0
  97. aiagents4pharma/talk2knowledgegraphs/Dockerfile +131 -0
  98. aiagents4pharma/talk2knowledgegraphs/README.md +1 -0
  99. aiagents4pharma/talk2knowledgegraphs/__init__.py +5 -0
  100. aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +5 -0
  101. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +99 -0
  102. aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +5 -0
  103. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
  104. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
  105. aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +5 -0
  106. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
  107. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +79 -0
  108. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +13 -0
  109. aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +5 -0
  110. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
  111. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
  112. aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/__init__.py +0 -0
  113. aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +33 -0
  114. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
  115. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
  116. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
  117. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
  118. aiagents4pharma/talk2knowledgegraphs/configs/utils/database/milvus/__init__.py +3 -0
  119. aiagents4pharma/talk2knowledgegraphs/configs/utils/database/milvus/default.yaml +61 -0
  120. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/ols_terms/default.yaml +3 -0
  121. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/reactome_pathways/default.yaml +3 -0
  122. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/uniprot_proteins/default.yaml +6 -0
  123. aiagents4pharma/talk2knowledgegraphs/configs/utils/pubchem_utils/default.yaml +5 -0
  124. aiagents4pharma/talk2knowledgegraphs/datasets/__init__.py +5 -0
  125. aiagents4pharma/talk2knowledgegraphs/datasets/biobridge_primekg.py +607 -0
  126. aiagents4pharma/talk2knowledgegraphs/datasets/dataset.py +25 -0
  127. aiagents4pharma/talk2knowledgegraphs/datasets/primekg.py +212 -0
  128. aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py +210 -0
  129. aiagents4pharma/talk2knowledgegraphs/docker-compose/cpu/.env.example +23 -0
  130. aiagents4pharma/talk2knowledgegraphs/docker-compose/cpu/docker-compose.yml +93 -0
  131. aiagents4pharma/talk2knowledgegraphs/docker-compose/gpu/.env.example +23 -0
  132. aiagents4pharma/talk2knowledgegraphs/docker-compose/gpu/docker-compose.yml +108 -0
  133. aiagents4pharma/talk2knowledgegraphs/entrypoint.sh +180 -0
  134. aiagents4pharma/talk2knowledgegraphs/install.md +165 -0
  135. aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +886 -0
  136. aiagents4pharma/talk2knowledgegraphs/states/__init__.py +5 -0
  137. aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +40 -0
  138. aiagents4pharma/talk2knowledgegraphs/tests/__init__.py +0 -0
  139. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +318 -0
  140. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +248 -0
  141. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +33 -0
  142. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +86 -0
  143. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +125 -0
  144. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +257 -0
  145. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +1444 -0
  146. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_multimodal_subgraph_extraction.py +159 -0
  147. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +152 -0
  148. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +201 -0
  149. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_database_milvus_connection_manager.py +812 -0
  150. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +51 -0
  151. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +49 -0
  152. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_nim_molmim.py +59 -0
  153. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +63 -0
  154. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py +47 -0
  155. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py +40 -0
  156. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +94 -0
  157. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ols.py +70 -0
  158. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_pubchem.py +45 -0
  159. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_reactome.py +44 -0
  160. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_uniprot.py +48 -0
  161. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +759 -0
  162. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +78 -0
  163. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_pubchem_utils.py +123 -0
  164. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +11 -0
  165. aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +138 -0
  166. aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
  167. aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +965 -0
  168. aiagents4pharma/talk2knowledgegraphs/tools/multimodal_subgraph_extraction.py +374 -0
  169. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +291 -0
  170. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +123 -0
  171. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +5 -0
  172. aiagents4pharma/talk2knowledgegraphs/utils/database/__init__.py +5 -0
  173. aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py +586 -0
  174. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +5 -0
  175. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/embeddings.py +81 -0
  176. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/huggingface.py +111 -0
  177. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/nim_molmim.py +54 -0
  178. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +87 -0
  179. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/sentence_transformer.py +73 -0
  180. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +12 -0
  181. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py +37 -0
  182. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py +129 -0
  183. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ols_terms.py +89 -0
  184. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/pubchem_strings.py +78 -0
  185. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/reactome_pathways.py +71 -0
  186. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/uniprot_proteins.py +98 -0
  187. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +5 -0
  188. aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +762 -0
  189. aiagents4pharma/talk2knowledgegraphs/utils/extractions/multimodal_pcst.py +298 -0
  190. aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +229 -0
  191. aiagents4pharma/talk2knowledgegraphs/utils/kg_utils.py +67 -0
  192. aiagents4pharma/talk2knowledgegraphs/utils/pubchem_utils.py +104 -0
  193. aiagents4pharma/talk2scholars/.dockerignore +13 -0
  194. aiagents4pharma/talk2scholars/Dockerfile +104 -0
  195. aiagents4pharma/talk2scholars/README.md +1 -0
  196. aiagents4pharma/talk2scholars/__init__.py +7 -0
  197. aiagents4pharma/talk2scholars/agents/__init__.py +13 -0
  198. aiagents4pharma/talk2scholars/agents/main_agent.py +89 -0
  199. aiagents4pharma/talk2scholars/agents/paper_download_agent.py +96 -0
  200. aiagents4pharma/talk2scholars/agents/pdf_agent.py +101 -0
  201. aiagents4pharma/talk2scholars/agents/s2_agent.py +135 -0
  202. aiagents4pharma/talk2scholars/agents/zotero_agent.py +127 -0
  203. aiagents4pharma/talk2scholars/configs/__init__.py +7 -0
  204. aiagents4pharma/talk2scholars/configs/agents/__init__.py +7 -0
  205. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/__init__.py +7 -0
  206. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/__init__.py +3 -0
  207. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +52 -0
  208. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/__init__.py +3 -0
  209. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/default.yaml +19 -0
  210. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/__init__.py +3 -0
  211. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +19 -0
  212. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/__init__.py +3 -0
  213. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +44 -0
  214. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/__init__.py +3 -0
  215. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +19 -0
  216. aiagents4pharma/talk2scholars/configs/app/__init__.py +7 -0
  217. aiagents4pharma/talk2scholars/configs/app/frontend/__init__.py +3 -0
  218. aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +72 -0
  219. aiagents4pharma/talk2scholars/configs/config.yaml +16 -0
  220. aiagents4pharma/talk2scholars/configs/tools/__init__.py +21 -0
  221. aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/__init__.py +3 -0
  222. aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +26 -0
  223. aiagents4pharma/talk2scholars/configs/tools/paper_download/__init__.py +3 -0
  224. aiagents4pharma/talk2scholars/configs/tools/paper_download/default.yaml +124 -0
  225. aiagents4pharma/talk2scholars/configs/tools/question_and_answer/__init__.py +3 -0
  226. aiagents4pharma/talk2scholars/configs/tools/question_and_answer/default.yaml +62 -0
  227. aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/__init__.py +3 -0
  228. aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/default.yaml +12 -0
  229. aiagents4pharma/talk2scholars/configs/tools/search/__init__.py +3 -0
  230. aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +26 -0
  231. aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/__init__.py +3 -0
  232. aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +26 -0
  233. aiagents4pharma/talk2scholars/configs/tools/zotero_read/__init__.py +3 -0
  234. aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +57 -0
  235. aiagents4pharma/talk2scholars/configs/tools/zotero_write/__inti__.py +3 -0
  236. aiagents4pharma/talk2scholars/configs/tools/zotero_write/default.yaml +55 -0
  237. aiagents4pharma/talk2scholars/docker-compose/cpu/.env.example +21 -0
  238. aiagents4pharma/talk2scholars/docker-compose/cpu/docker-compose.yml +90 -0
  239. aiagents4pharma/talk2scholars/docker-compose/gpu/.env.example +21 -0
  240. aiagents4pharma/talk2scholars/docker-compose/gpu/docker-compose.yml +105 -0
  241. aiagents4pharma/talk2scholars/install.md +122 -0
  242. aiagents4pharma/talk2scholars/state/__init__.py +7 -0
  243. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +98 -0
  244. aiagents4pharma/talk2scholars/tests/__init__.py +3 -0
  245. aiagents4pharma/talk2scholars/tests/test_agents_main_agent.py +256 -0
  246. aiagents4pharma/talk2scholars/tests/test_agents_paper_agents_download_agent.py +139 -0
  247. aiagents4pharma/talk2scholars/tests/test_agents_pdf_agent.py +114 -0
  248. aiagents4pharma/talk2scholars/tests/test_agents_s2_agent.py +198 -0
  249. aiagents4pharma/talk2scholars/tests/test_agents_zotero_agent.py +160 -0
  250. aiagents4pharma/talk2scholars/tests/test_s2_tools_display_dataframe.py +91 -0
  251. aiagents4pharma/talk2scholars/tests/test_s2_tools_query_dataframe.py +191 -0
  252. aiagents4pharma/talk2scholars/tests/test_states_state.py +38 -0
  253. aiagents4pharma/talk2scholars/tests/test_tools_paper_downloader.py +507 -0
  254. aiagents4pharma/talk2scholars/tests/test_tools_question_and_answer_tool.py +105 -0
  255. aiagents4pharma/talk2scholars/tests/test_tools_s2_multi.py +307 -0
  256. aiagents4pharma/talk2scholars/tests/test_tools_s2_retrieve.py +67 -0
  257. aiagents4pharma/talk2scholars/tests/test_tools_s2_search.py +286 -0
  258. aiagents4pharma/talk2scholars/tests/test_tools_s2_single.py +298 -0
  259. aiagents4pharma/talk2scholars/tests/test_utils_arxiv_downloader.py +469 -0
  260. aiagents4pharma/talk2scholars/tests/test_utils_base_paper_downloader.py +598 -0
  261. aiagents4pharma/talk2scholars/tests/test_utils_biorxiv_downloader.py +669 -0
  262. aiagents4pharma/talk2scholars/tests/test_utils_medrxiv_downloader.py +500 -0
  263. aiagents4pharma/talk2scholars/tests/test_utils_nvidia_nim_reranker.py +117 -0
  264. aiagents4pharma/talk2scholars/tests/test_utils_pdf_answer_formatter.py +67 -0
  265. aiagents4pharma/talk2scholars/tests/test_utils_pdf_batch_processor.py +92 -0
  266. aiagents4pharma/talk2scholars/tests/test_utils_pdf_collection_manager.py +173 -0
  267. aiagents4pharma/talk2scholars/tests/test_utils_pdf_document_processor.py +68 -0
  268. aiagents4pharma/talk2scholars/tests/test_utils_pdf_generate_answer.py +72 -0
  269. aiagents4pharma/talk2scholars/tests/test_utils_pdf_gpu_detection.py +129 -0
  270. aiagents4pharma/talk2scholars/tests/test_utils_pdf_paper_loader.py +116 -0
  271. aiagents4pharma/talk2scholars/tests/test_utils_pdf_rag_pipeline.py +88 -0
  272. aiagents4pharma/talk2scholars/tests/test_utils_pdf_retrieve_chunks.py +190 -0
  273. aiagents4pharma/talk2scholars/tests/test_utils_pdf_singleton_manager.py +159 -0
  274. aiagents4pharma/talk2scholars/tests/test_utils_pdf_vector_normalization.py +121 -0
  275. aiagents4pharma/talk2scholars/tests/test_utils_pdf_vector_store.py +406 -0
  276. aiagents4pharma/talk2scholars/tests/test_utils_pubmed_downloader.py +1007 -0
  277. aiagents4pharma/talk2scholars/tests/test_utils_read_helper_utils.py +106 -0
  278. aiagents4pharma/talk2scholars/tests/test_utils_s2_utils_ext_ids.py +403 -0
  279. aiagents4pharma/talk2scholars/tests/test_utils_tool_helper_utils.py +85 -0
  280. aiagents4pharma/talk2scholars/tests/test_utils_zotero_human_in_the_loop.py +266 -0
  281. aiagents4pharma/talk2scholars/tests/test_utils_zotero_path.py +496 -0
  282. aiagents4pharma/talk2scholars/tests/test_utils_zotero_pdf_downloader_utils.py +46 -0
  283. aiagents4pharma/talk2scholars/tests/test_utils_zotero_read.py +743 -0
  284. aiagents4pharma/talk2scholars/tests/test_utils_zotero_write.py +151 -0
  285. aiagents4pharma/talk2scholars/tools/__init__.py +9 -0
  286. aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +12 -0
  287. aiagents4pharma/talk2scholars/tools/paper_download/paper_downloader.py +442 -0
  288. aiagents4pharma/talk2scholars/tools/paper_download/utils/__init__.py +22 -0
  289. aiagents4pharma/talk2scholars/tools/paper_download/utils/arxiv_downloader.py +207 -0
  290. aiagents4pharma/talk2scholars/tools/paper_download/utils/base_paper_downloader.py +336 -0
  291. aiagents4pharma/talk2scholars/tools/paper_download/utils/biorxiv_downloader.py +313 -0
  292. aiagents4pharma/talk2scholars/tools/paper_download/utils/medrxiv_downloader.py +196 -0
  293. aiagents4pharma/talk2scholars/tools/paper_download/utils/pubmed_downloader.py +323 -0
  294. aiagents4pharma/talk2scholars/tools/pdf/__init__.py +7 -0
  295. aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +170 -0
  296. aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py +37 -0
  297. aiagents4pharma/talk2scholars/tools/pdf/utils/answer_formatter.py +62 -0
  298. aiagents4pharma/talk2scholars/tools/pdf/utils/batch_processor.py +198 -0
  299. aiagents4pharma/talk2scholars/tools/pdf/utils/collection_manager.py +172 -0
  300. aiagents4pharma/talk2scholars/tools/pdf/utils/document_processor.py +76 -0
  301. aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py +97 -0
  302. aiagents4pharma/talk2scholars/tools/pdf/utils/get_vectorstore.py +59 -0
  303. aiagents4pharma/talk2scholars/tools/pdf/utils/gpu_detection.py +150 -0
  304. aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py +97 -0
  305. aiagents4pharma/talk2scholars/tools/pdf/utils/paper_loader.py +123 -0
  306. aiagents4pharma/talk2scholars/tools/pdf/utils/rag_pipeline.py +113 -0
  307. aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py +197 -0
  308. aiagents4pharma/talk2scholars/tools/pdf/utils/singleton_manager.py +140 -0
  309. aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py +86 -0
  310. aiagents4pharma/talk2scholars/tools/pdf/utils/vector_normalization.py +150 -0
  311. aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py +327 -0
  312. aiagents4pharma/talk2scholars/tools/s2/__init__.py +21 -0
  313. aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py +110 -0
  314. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +111 -0
  315. aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +233 -0
  316. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +128 -0
  317. aiagents4pharma/talk2scholars/tools/s2/search.py +101 -0
  318. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +102 -0
  319. aiagents4pharma/talk2scholars/tools/s2/utils/__init__.py +5 -0
  320. aiagents4pharma/talk2scholars/tools/s2/utils/multi_helper.py +223 -0
  321. aiagents4pharma/talk2scholars/tools/s2/utils/search_helper.py +205 -0
  322. aiagents4pharma/talk2scholars/tools/s2/utils/single_helper.py +216 -0
  323. aiagents4pharma/talk2scholars/tools/zotero/__init__.py +7 -0
  324. aiagents4pharma/talk2scholars/tools/zotero/utils/__init__.py +7 -0
  325. aiagents4pharma/talk2scholars/tools/zotero/utils/read_helper.py +270 -0
  326. aiagents4pharma/talk2scholars/tools/zotero/utils/review_helper.py +74 -0
  327. aiagents4pharma/talk2scholars/tools/zotero/utils/write_helper.py +194 -0
  328. aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_path.py +180 -0
  329. aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_pdf_downloader.py +133 -0
  330. aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +105 -0
  331. aiagents4pharma/talk2scholars/tools/zotero/zotero_review.py +162 -0
  332. aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +91 -0
  333. aiagents4pharma-0.0.0.dist-info/METADATA +335 -0
  334. aiagents4pharma-0.0.0.dist-info/RECORD +336 -0
  335. aiagents4pharma-0.0.0.dist-info/WHEEL +4 -0
  336. aiagents4pharma-0.0.0.dist-info/licenses/LICENSE +21 -0
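The listing above mirrors the wheel's RECORD file. Since a wheel is an ordinary zip archive, the same file list can be reproduced locally from a downloaded copy with a few lines of Python; the wheel path below is an assumed local filename, not part of this diff.

import zipfile

# List every file packaged in the wheel (assumed to be in the current directory).
with zipfile.ZipFile("aiagents4pharma-0.0.0-py3-none-any.whl") as whl:
    for name in whl.namelist():
        print(name)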
aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py
@@ -0,0 +1,762 @@
+ """
+ Extraction of multimodal subgraph using Prize-Collecting Steiner Tree (PCST) algorithm.
+ """
+
+ import asyncio
+ import logging
+ import platform
+ import subprocess
+ from typing import NamedTuple
+
+ import numpy as np
+ import pandas as pd
+ import pcst_fast
+ from pymilvus import Collection
+
+ try:
+     import cudf  # type: ignore
+     import cupy as cp  # type: ignore
+
+     CUDF_AVAILABLE = True
+ except ImportError:
+     CUDF_AVAILABLE = False
+     cudf = None
+     cp = None
+
+ # Initialize logger
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ class SystemDetector:
+     """Detect system capabilities and choose appropriate libraries."""
+
+     def __init__(self):
+         self.os_type = platform.system().lower()  # 'windows', 'linux', 'darwin'
+         self.architecture = platform.machine().lower()  # 'x86_64', 'arm64', etc.
+         self.has_nvidia_gpu = self._detect_nvidia_gpu()
+         self.use_gpu = self.has_nvidia_gpu and self.os_type != "darwin"  # No CUDA on macOS
+
+         logger.info("System Detection Results:")
+         logger.info(" OS: %s", self.os_type)
+         logger.info(" Architecture: %s", self.architecture)
+         logger.info(" NVIDIA GPU detected: %s", self.has_nvidia_gpu)
+         logger.info(" Will use GPU acceleration: %s", self.use_gpu)
+
+     def _detect_nvidia_gpu(self) -> bool:
+         """Detect if NVIDIA GPU is available."""
+         try:
+             # Try nvidia-smi command
+             result = subprocess.run(
+                 ["nvidia-smi"], capture_output=True, text=True, timeout=10, check=False
+             )
+             return result.returncode == 0
+         except (
+             subprocess.TimeoutExpired,
+             FileNotFoundError,
+             subprocess.SubprocessError,
+         ):
+             return False
+
+     def get_system_info(self) -> dict:
+         """Get comprehensive system information."""
+         return {
+             "os_type": self.os_type,
+             "architecture": self.architecture,
+             "has_nvidia_gpu": self.has_nvidia_gpu,
+             "use_gpu": self.use_gpu,
+         }
+
+     def is_gpu_compatible(self) -> bool:
+         """Check if the system is compatible with GPU acceleration."""
+         return self.has_nvidia_gpu and self.os_type != "darwin"
+
+
+ class DynamicLibraryLoader:
+     """Dynamically load libraries based on system capabilities."""
+
+     def __init__(self, detector: SystemDetector):
+         self.detector = detector
+         self.use_gpu = detector.use_gpu
+
+         # Initialize attributes that will be set later
+         self.py = None
+         self.df = None
+         self.pd = None
+         self.np = None
+         self.cudf = None
+         self.cp = None
+
+         # Import libraries based on system capabilities
+         self._import_libraries()
+
+         # Dynamic settings based on hardware
+         self.normalize_vectors = self.use_gpu  # Only normalize for GPU
+         self.metric_type = "IP" if self.use_gpu else "COSINE"
+
+         logger.info("Library Configuration:")
+         logger.info(" Using GPU acceleration: %s", self.use_gpu)
+         logger.info(" Vector normalization: %s", self.normalize_vectors)
+         logger.info(" Metric type: %s", self.metric_type)
+
+     def _import_libraries(self):
+         """Dynamically import libraries based on system capabilities."""
+         # Set base libraries
+         self.pd = pd
+         self.np = np
+
+         # Conditionally import GPU libraries
+         if self.detector.use_gpu:
+             if CUDF_AVAILABLE:
+                 self.cudf = cudf
+                 self.cp = cp
+                 self.py = cp  # Use cupy for array operations
+                 self.df = cudf  # Use cudf for dataframes
+                 logger.info("Successfully imported GPU libraries (cudf, cupy)")
+             else:
+                 logger.error("cudf or cupy not found. Falling back to CPU mode.")
+                 self.detector.use_gpu = False
+                 self.use_gpu = False
+                 self._setup_cpu_mode()
+         else:
+             self._setup_cpu_mode()
+
+     def _setup_cpu_mode(self):
+         """Setup CPU mode with numpy and pandas."""
+         self.py = self.np  # Use numpy for array operations
+         self.df = self.pd  # Use pandas for dataframes
+         self.normalize_vectors = False
+         self.metric_type = "COSINE"
+         logger.info("Using CPU mode with numpy and pandas")
+
+     def normalize_matrix(self, matrix, axis: int = 1):
+         """Normalize matrix using appropriate library."""
+         if not self.normalize_vectors:
+             return matrix
+
+         if self.use_gpu:
+             # Use cupy for GPU
+             matrix_cp = self.cp.asarray(matrix).astype(self.cp.float32)
+             norms = self.cp.linalg.norm(matrix_cp, axis=axis, keepdims=True)
+             return matrix_cp / norms
+         # CPU mode doesn't normalize for COSINE similarity
+         return matrix
+
+     def to_list(self, data):
+         """Convert data to list format."""
+         if hasattr(data, "tolist"):
+             return data.tolist()
+         if hasattr(data, "to_arrow"):
+             return data.to_arrow().to_pylist()
+         return list(data)
+
+
+ class MultimodalPCSTPruning(NamedTuple):
+     """
+     Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by the
+     G-Retriever paper (He et al., 'G-Retriever: Retrieval-Augmented Generation for Textual
+     Graph Understanding and Question Answering', NeurIPS 2024).
+     https://arxiv.org/abs/2402.07630
+     https://github.com/XiaoxinHe/G-Retriever/blob/main/src/dataset/utils/retrieval.py
+
+     Args:
+         topk: The number of top nodes to consider.
+         topk_e: The number of top edges to consider.
+         cost_e: The cost of the edges.
+         c_const: The constant value for the cost of the edges computation.
+         root: The root node of the subgraph, -1 for unrooted.
+         num_clusters: The number of clusters.
+         pruning: The pruning strategy to use.
+         verbosity_level: The verbosity level.
+         use_description: Whether to use description embeddings.
+         metric_type: The similarity metric type (dynamic based on hardware).
+         loader: The dynamic library loader instance.
+     """
+
+     topk: int = 3
+     topk_e: int = 3
+     cost_e: float = 0.5
+     c_const: float = 0.01
+     root: int = -1
+     num_clusters: int = 1
+     pruning: str = "gw"
+     verbosity_level: int = 0
+     use_description: bool = False
+     metric_type: str = None  # Will be set dynamically
+     loader: DynamicLibraryLoader = None
+
+     def prepare_collections(self, cfg: dict, modality: str) -> dict:
+         """
+         Prepare the collections for nodes, node-type specific nodes, and edges in Milvus.
+
+         Args:
+             cfg: The configuration dictionary containing the Milvus setup.
+             modality: The modality to use for the subgraph extraction.
+
+         Returns:
+             A dictionary containing the collections of nodes, node-type specific nodes,
+             and edges.
+         """
+         # Initialize the collections dictionary
+         colls = {}
+
+         # Load the collection for nodes
+         colls["nodes"] = Collection(name=f"{cfg.milvus_db.database_name}_nodes")
+
+         if modality != "prompt":
+             # Load the collection for the specific node type
+             colls["nodes_type"] = Collection(
+                 f"{cfg.milvus_db.database_name}_nodes_{modality.replace('/', '_')}"
+             )
+
+         # Load the collection for edges
+         colls["edges"] = Collection(name=f"{cfg.milvus_db.database_name}_edges")
+
+         # Load the collections
+         for coll in colls.values():
+             coll.load()
+
+         return colls
+
+     async def load_edge_index_async(self, cfg: dict, _connection_manager=None) -> np.ndarray:
+         """
+         Load edge index using hybrid async/sync approach to avoid event loop issues.
+
+         This method queries the edges collection to get head_index and tail_index,
+         eliminating the need for pickle caching and reducing memory usage.
+
+         Args:
+             cfg: The configuration dictionary containing the Milvus setup.
+             _connection_manager: Unused parameter for interface compatibility.
+
+         Returns:
+             numpy.ndarray: Edge index array with shape [2, num_edges]
+         """
+         logger.log(logging.INFO, "Loading edge index from Milvus collection (hybrid)")
+
+         def load_edges_sync():
+             """Load edges synchronously to avoid event loop issues."""
+
+             collection_name = f"{cfg.milvus_db.database_name}_edges"
+             edges_collection = Collection(name=collection_name)
+             edges_collection.load()
+
+             # Query all edges in batches
+             batch_size = getattr(cfg.milvus_db, "query_batch_size", 10000)
+             total_entities = edges_collection.num_entities
+             logger.log(logging.INFO, "Total edges to process: %d", total_entities)
+
+             head_list = []
+             tail_list = []
+
+             for start in range(0, total_entities, batch_size):
+                 end = min(start + batch_size, total_entities)
+                 logger.debug("Processing edge batch: %d to %d", start, end)
+
+                 batch = edges_collection.query(
+                     expr=f"triplet_index >= {start} and triplet_index < {end}",
+                     output_fields=["head_index", "tail_index"],
+                 )
+
+                 head_list.extend([r["head_index"] for r in batch])
+                 tail_list.extend([r["tail_index"] for r in batch])
+
+             # Convert to numpy array format expected by PCST
+             edge_index = self.loader.py.array([head_list, tail_list])
+             logger.log(
+                 logging.INFO,
+                 "Edge index loaded (hybrid): shape %s",
+                 str(edge_index.shape),
+             )
+
+             return edge_index
+
+         # Run in thread to avoid event loop conflicts
+         return await asyncio.to_thread(load_edges_sync)
+
+     def load_edge_index(self, cfg: dict) -> np.ndarray:
+         """
+         Load edge index synchronously from Milvus collection.
+
+         This method queries the edges collection to get head_index and tail_index.
+
+         Args:
+             cfg: The configuration dictionary containing the Milvus setup.
+
+         Returns:
+             numpy.ndarray: Edge index array with shape [2, num_edges]
+         """
+         logger.log(logging.INFO, "Loading edge index from Milvus collection (sync)")
+
+         collection_name = f"{cfg.milvus_db.database_name}_edges"
+         edges_collection = Collection(name=collection_name)
+         edges_collection.load()
+
+         # Query all edges in batches
+         batch_size = getattr(cfg.milvus_db, "query_batch_size", 10000)
+         total_entities = edges_collection.num_entities
+         logger.log(logging.INFO, "Total edges to process: %d", total_entities)
+
+         head_list = []
+         tail_list = []
+
+         for start in range(0, total_entities, batch_size):
+             end = min(start + batch_size, total_entities)
+             logger.debug("Processing edge batch: %d to %d", start, end)
+
+             batch = edges_collection.query(
+                 expr=f"triplet_index >= {start} and triplet_index < {end}",
+                 output_fields=["head_index", "tail_index"],
+             )
+
+             head_list.extend([r["head_index"] for r in batch])
+             tail_list.extend([r["tail_index"] for r in batch])
+
+         # Convert to numpy array format expected by PCST
+         edge_index = self.loader.py.array([head_list, tail_list])
+         logger.log(
+             logging.INFO,
+             "Edge index loaded (sync): shape %s",
+             str(edge_index.shape),
+         )
+
+         return edge_index
+
+     def _compute_node_prizes(self, query_emb: list, colls: dict) -> dict:
+         """
+         Compute the node prizes based on the similarity between the query and nodes.
+
+         Args:
+             query_emb: The query embedding. This can be an embedding of
+                 a prompt, sequence, or any other feature to be used for the subgraph extraction.
+             colls: The collections of nodes, node-type specific nodes, and edges in Milvus.
+
+         Returns:
+             The prizes of the nodes.
+         """
+         # Initialize several variables
+         topk = min(self.topk, colls["nodes"].num_entities)
+         n_prizes = self.loader.py.zeros(colls["nodes"].num_entities, dtype=self.loader.py.float32)
+
+         # Get the actual metric type to use
+         actual_metric_type = self.metric_type or self.loader.metric_type
+
+         # Calculate similarity for text features and update the score
+         if self.use_description:
+             # Search the collection with the text embedding
+             res = colls["nodes"].search(
+                 data=[query_emb],
+                 anns_field="desc_emb",
+                 param={"metric_type": actual_metric_type},
+                 limit=topk,
+                 output_fields=["node_id"],
+             )
+         else:
+             # Search the collection with the query embedding
+             res = colls["nodes_type"].search(
+                 data=[query_emb],
+                 anns_field="feat_emb",
+                 param={"metric_type": actual_metric_type},
+                 limit=topk,
+                 output_fields=["node_id"],
+             )
+
+         # Update the prizes based on the search results
+         n_prizes[[r.id for r in res[0]]] = self.loader.py.arange(topk, 0, -1).astype(
+             self.loader.py.float32
+         )
+
+         return n_prizes
+
+     async def _compute_node_prizes_async(
+         self,
+         query_emb: list,
+         collection_name: str,
+         connection_manager,
+         use_description: bool = False,
+     ) -> dict:
+         """
+         Compute the node prizes asynchronously using connection manager.
+
+         Args:
+             query_emb: The query embedding
+             collection_name: Name of the collection to search
+             connection_manager: The MilvusConnectionManager instance
+             use_description: Whether to use description embeddings
+
+         Returns:
+             The prizes of the nodes
+         """
+         # Get collection stats for initialization
+         stats = await connection_manager.async_get_collection_stats(collection_name)
+         num_entities = stats["num_entities"]
+
+         # Initialize prizes array
+         topk = min(self.topk, num_entities)
+         n_prizes = self.loader.py.zeros(num_entities, dtype=self.loader.py.float32)
+
+         # Get the actual metric type to use
+         actual_metric_type = self.metric_type or self.loader.metric_type
+
+         # Determine search field based on use_description
+         anns_field = "desc_emb" if use_description else "feat_emb"
+
+         # Perform async search
+         results = await connection_manager.async_search(
+             collection_name=collection_name,
+             data=[query_emb],
+             anns_field=anns_field,
+             param={"metric_type": actual_metric_type},
+             limit=topk,
+             output_fields=["node_id"],
+         )
+
+         # Update the prizes based on the search results
+         if results and len(results) > 0:
+             result_ids = [hit["id"] for hit in results[0]]
+             n_prizes[result_ids] = self.loader.py.arange(topk, 0, -1).astype(self.loader.py.float32)
+
+         return n_prizes
+
+     def _compute_edge_prizes(self, text_emb: list, colls: dict):
+         """
+         Compute the edge prizes based on the similarity between the query and edges.
+
+         Args:
+             text_emb: The textual description embedding.
+             colls: The collections of nodes, node-type specific nodes, and edges in Milvus.
+
+         Returns:
+             The prizes of the edges.
+         """
+         # Initialize several variables
+         topk_e = min(self.topk_e, colls["edges"].num_entities)
+         e_prizes = self.loader.py.zeros(colls["edges"].num_entities, dtype=self.loader.py.float32)
+
+         # Get the actual metric type to use
+         actual_metric_type = self.metric_type or self.loader.metric_type
+
+         # Search the collection with the query embedding
+         res = colls["edges"].search(
+             data=[text_emb],
+             anns_field="feat_emb",
+             param={"metric_type": actual_metric_type},
+             limit=topk_e,  # Only retrieve the top-k edges
+             output_fields=["head_id", "tail_id"],
+         )
+
+         # Update the prizes based on the search results
+         e_prizes[[r.id for r in res[0]]] = [r.score for r in res[0]]
+
+         # Further process the edge_prizes
+         unique_prizes, inverse_indices = self.loader.py.unique(e_prizes, return_inverse=True)
+         topk_e_values = unique_prizes[self.loader.py.argsort(-unique_prizes)[:topk_e]]
+         last_topk_e_value = topk_e
+         for k in range(topk_e):
+             indices = inverse_indices == (unique_prizes == topk_e_values[k]).nonzero()[0]
+             value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
+             e_prizes[indices] = value
+             last_topk_e_value = value * (1 - self.c_const)
+
+         return e_prizes
+
+     async def _compute_edge_prizes_async(
+         self, text_emb: list, collection_name: str, connection_manager
+     ) -> dict:
+         """
+         Compute the edge prizes asynchronously using connection manager.
+
+         Args:
+             text_emb: The textual description embedding
+             collection_name: Name of the edges collection
+             connection_manager: The MilvusConnectionManager instance
+
+         Returns:
+             The prizes of the edges
+         """
+         # Get collection stats for initialization
+         stats = await connection_manager.async_get_collection_stats(collection_name)
+         num_entities = stats["num_entities"]
+
+         # Initialize prizes array
+         topk_e = min(self.topk_e, num_entities)
+         e_prizes = self.loader.py.zeros(num_entities, dtype=self.loader.py.float32)
+
+         # Get the actual metric type to use
+         actual_metric_type = self.metric_type or self.loader.metric_type
+
+         # Perform async search
+         results = await connection_manager.async_search(
+             collection_name=collection_name,
+             data=[text_emb],
+             anns_field="feat_emb",
+             param={"metric_type": actual_metric_type},
+             limit=topk_e,
+             output_fields=["head_id", "tail_id"],
+         )
+
+         # Update the prizes based on the search results
+         if results and len(results) > 0:
+             result_ids = [hit["id"] for hit in results[0]]
+             result_scores = [hit["distance"] for hit in results[0]]  # Use distance/score
+             e_prizes[result_ids] = result_scores
+
+         # Process edge prizes using helper method
+         return self._process_edge_prizes(e_prizes, topk_e)
+
+     def _process_edge_prizes(self, e_prizes, topk_e):
+         """Helper method to process edge prizes and reduce complexity."""
+         unique_prizes, inverse_indices = self.loader.py.unique(e_prizes, return_inverse=True)
+         sorted_indices = self.loader.py.argsort(-unique_prizes)[:topk_e]
+         topk_e_values = unique_prizes[sorted_indices]
+         last_topk_e_value = topk_e
+
+         for k in range(topk_e):
+             indices = inverse_indices == (unique_prizes == topk_e_values[k]).nonzero()[0]
+             value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
+             e_prizes[indices] = value
+             last_topk_e_value = value * (1 - self.c_const)
+
+         return e_prizes
+
+     def compute_prizes(self, text_emb: list, query_emb: list, colls: dict) -> dict:
+         """
+         Compute the node prizes based on the cosine similarity between the query and nodes,
+         as well as the edge prizes based on the cosine similarity between the query and edges.
+         Note that the node and edge embeddings must use the same embedding model and
+         dimensions as the query.
+
+         Args:
+             text_emb: The textual description embedding.
+             query_emb: The query embedding. This can be an embedding of
+                 a prompt, sequence, or any other feature to be used for the subgraph extraction.
+             colls: The collections of nodes, node-type specific nodes, and edges in Milvus.
+
+         Returns:
+             The prizes of the nodes and edges.
+         """
+         # Compute prizes for nodes
+         logger.log(logging.INFO, "_compute_node_prizes")
+         n_prizes = self._compute_node_prizes(query_emb, colls)
+
+         # Compute prizes for edges
+         logger.log(logging.INFO, "_compute_edge_prizes")
+         e_prizes = self._compute_edge_prizes(text_emb, colls)
+
+         return {"nodes": n_prizes, "edges": e_prizes}
+
+     async def compute_prizes_async(
+         self, text_emb: list, query_emb: list, cfg: dict, modality: str
+     ) -> dict:
+         """
+         Compute node and edge prizes asynchronously in parallel using sync fallback.
+
+         Args:
+             text_emb: The textual description embedding
+             query_emb: The query embedding
+             cfg: The configuration dictionary containing the Milvus setup
+             modality: The modality to use for the subgraph extraction
+
+         Returns:
+             The prizes of the nodes and edges
+         """
+         logger.log(logging.INFO, "Computing prizes in parallel (hybrid async/sync)")
+
+         # Use existing sync method wrapped in asyncio.to_thread
+         colls = self.prepare_collections(cfg, modality)
+         return await asyncio.to_thread(self.compute_prizes, text_emb, query_emb, colls)
+
+     def compute_subgraph_costs(self, edge_index, num_nodes: int, prizes: dict):
+         """
+         Compute the costs of constructing the subgraph, as proposed by the G-Retriever paper.
+
+         Args:
+             edge_index: The edge index of the graph, consisting of source and destination nodes.
+             num_nodes: The number of nodes in the graph.
+             prizes: The prizes of the nodes and the edges.
+
+         Returns:
+             edges: The edges of the subgraph, consisting of edges and number of edges without
+                 virtual edges.
+             prizes: The prizes of the subgraph.
+             costs: The costs of the subgraph.
+             mapping: The mapping of real edges and virtual nodes back to the original
+                 edge indices.
+         """
+         # Initialize several variables
+         real_ = {}
+         virt_ = {}
+
+         # Update edge cost threshold
+         updated_cost_e = min(
+             self.cost_e,
+             self.loader.py.max(prizes["edges"]).item() * (1 - self.c_const / 2),
+         )
+
+         # Masks for real and virtual edges
+         logger.log(logging.INFO, "Creating masks for real and virtual edges")
+         real_["mask"] = prizes["edges"] <= updated_cost_e
+         virt_["mask"] = ~real_["mask"]
+
+         # Real edge indices
+         logger.log(logging.INFO, "Computing real edges")
+         real_["indices"] = self.loader.py.nonzero(real_["mask"])[0]
+         real_["src"] = edge_index[0][real_["indices"]]
+         real_["dst"] = edge_index[1][real_["indices"]]
+         real_["edges"] = self.loader.py.stack([real_["src"], real_["dst"]], axis=1)
+         real_["costs"] = updated_cost_e - prizes["edges"][real_["indices"]]
+
+         # Edge index mapping: local real edge idx -> original global index
+         logger.log(logging.INFO, "Creating mapping for real edges")
+         mapping_edges = dict(
+             zip(range(len(real_["indices"])), self.loader.to_list(real_["indices"]), strict=False)
+         )
+
+         # Virtual edge handling
+         logger.log(logging.INFO, "Computing virtual edges")
+         virt_["indices"] = self.loader.py.nonzero(virt_["mask"])[0]
+         virt_["src"] = edge_index[0][virt_["indices"]]
+         virt_["dst"] = edge_index[1][virt_["indices"]]
+         virt_["prizes"] = prizes["edges"][virt_["indices"]] - updated_cost_e
+
+         # Generate virtual node IDs
+         logger.log(logging.INFO, "Generating virtual node IDs")
+         virt_["num"] = virt_["indices"].shape[0]
+         virt_["node_ids"] = self.loader.py.arange(num_nodes, num_nodes + virt_["num"])
+
+         # Virtual edges: (src → virtual), (virtual → dst)
+         logger.log(logging.INFO, "Creating virtual edges")
+         virt_["edges_1"] = self.loader.py.stack([virt_["src"], virt_["node_ids"]], axis=1)
+         virt_["edges_2"] = self.loader.py.stack([virt_["node_ids"], virt_["dst"]], axis=1)
+         virt_["edges"] = self.loader.py.concatenate([virt_["edges_1"], virt_["edges_2"]], axis=0)
+         virt_["costs"] = self.loader.py.zeros(
+             (virt_["edges"].shape[0],), dtype=real_["costs"].dtype
+         )
+
+         # Combine real and virtual edges/costs
+         logger.log(logging.INFO, "Combining real and virtual edges/costs")
+         all_edges = self.loader.py.concatenate([real_["edges"], virt_["edges"]], axis=0)
+         all_costs = self.loader.py.concatenate([real_["costs"], virt_["costs"]], axis=0)
+
+         # Final prizes
+         logger.log(logging.INFO, "Getting final prizes")
+         final_prizes = self.loader.py.concatenate([prizes["nodes"], virt_["prizes"]], axis=0)
+
+         # Mapping virtual node ID -> edge index in original graph
+         logger.log(logging.INFO, "Creating mapping for virtual nodes")
+         mapping_nodes = dict(
+             zip(
+                 self.loader.to_list(virt_["node_ids"]),
+                 self.loader.to_list(virt_["indices"]),
+                 strict=False,
+             )
+         )
+
+         # Build return values
+         logger.log(logging.INFO, "Building return values")
+         edges_dict = {
+             "edges": all_edges,
+             "num_prior_edges": real_["edges"].shape[0],
+         }
+         mapping = {
+             "edges": mapping_edges,
+             "nodes": mapping_nodes,
+         }
+
+         return edges_dict, final_prizes, all_costs, mapping
+
+     def get_subgraph_nodes_edges(
+         self, num_nodes: int, vertices, edges_dict: dict, mapping: dict
+     ) -> dict:
+         """
+         Get the selected nodes and edges of the subgraph based on the vertices and edges computed
+         by the PCST algorithm.
+
+         Args:
+             num_nodes: The number of nodes in the graph.
+             vertices: The vertices selected by the PCST algorithm.
+             edges_dict: A dictionary containing the edges and the number of prior edges.
+             mapping: A dictionary containing the mapping of nodes and edges.
+
+         Returns:
+             The selected nodes and edges of the extracted subgraph.
+         """
+         # Get edges information
+         edges = edges_dict["edges"]
+         num_prior_edges = edges_dict["num_prior_edges"]
+
+         # Retrieve the selected nodes and edges based on the given vertices and edges
+         subgraph_nodes = vertices[vertices < num_nodes]
+         subgraph_edges = [mapping["edges"][e.item()] for e in edges if e < num_prior_edges]
+         virtual_vertices = vertices[vertices >= num_nodes]
+         if len(virtual_vertices) > 0:
+             virtual_edges = [mapping["nodes"][i.item()] for i in virtual_vertices]
+             subgraph_edges = self.loader.py.array(subgraph_edges + virtual_edges)
+         edge_index = edges_dict["edge_index"][:, subgraph_edges]
+         subgraph_nodes = self.loader.py.unique(
+             self.loader.py.concatenate([subgraph_nodes, edge_index[0], edge_index[1]])
+         )
+
+         return {"nodes": subgraph_nodes, "edges": subgraph_edges}
+
+     def extract_subgraph(self, text_emb: list, query_emb: list, modality: str, cfg: dict) -> dict:
+         """
+         Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.
+
+         Args:
+             text_emb: The textual description embedding.
+             query_emb: The query embedding. This can be an embedding of
+                 a prompt, sequence, or any other feature to be used for the subgraph extraction.
+             modality: The modality to use for the subgraph extraction
+                 (e.g., "text", "sequence", "smiles").
+             cfg: The configuration dictionary containing the Milvus setup.
+
+         Returns:
+             The selected nodes and edges of the subgraph.
+         """
+         # Load the collections for nodes
+         logger.log(logging.INFO, "Preparing collections")
+         colls = self.prepare_collections(cfg, modality)
+
+         # Load edge index directly from Milvus (replaces pickle cache)
+         logger.log(logging.INFO, "Loading edge index from Milvus")
+         edge_index = self.load_edge_index(cfg)
+
+         # Assert the topk and topk_e values for subgraph retrieval
+         assert self.topk > 0, "topk must be greater than 0"
+         assert self.topk_e > 0, "topk_e must be greater than 0"
+
+         # Retrieve the top-k nodes and edges based on the query embedding
+         logger.log(logging.INFO, "compute_prizes")
+         prizes = self.compute_prizes(text_emb, query_emb, colls)
+
+         # Compute costs in constructing the subgraph
+         logger.log(logging.INFO, "compute_subgraph_costs")
+         edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(
+             edge_index, colls["nodes"].num_entities, prizes
+         )
+
+         # Retrieve the subgraph using the PCST algorithm
+         logger.log(logging.INFO, "Running PCST algorithm")
+         result_vertices, result_edges = pcst_fast.pcst_fast(
+             edges_dict["edges"].tolist(),
+             prizes.tolist(),
+             costs.tolist(),
+             self.root,
+             self.num_clusters,
+             self.pruning,
+             self.verbosity_level,
+         )
+
+         # Get subgraph nodes and edges based on the result of the PCST algorithm
+         logger.log(logging.INFO, "Getting subgraph nodes and edges")
+         subgraph = self.get_subgraph_nodes_edges(
+             colls["nodes"].num_entities,
+             self.loader.py.asarray(result_vertices),
+             {
+                 "edges": self.loader.py.asarray(result_edges),
+                 "num_prior_edges": edges_dict["num_prior_edges"],
+                 "edge_index": edge_index,
+             },
+             mapping,
+         )
+         print(subgraph)
+
+         return subgraph
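For orientation, below is a minimal, hypothetical usage sketch of the classes added in this file; it is not part of the package. It assumes a Milvus instance reachable on the default local port, a Hydra-style cfg object exposing only the attributes this module reads (milvus_db.database_name and milvus_db.query_batch_size, with an assumed database name), and dummy 768-dimensional embeddings standing in for real model output, whose dimension must match the indexed collections.

from types import SimpleNamespace

from pymilvus import connections

from aiagents4pharma.talk2knowledgegraphs.utils.extractions.milvus_multimodal_pcst import (
    DynamicLibraryLoader,
    MultimodalPCSTPruning,
    SystemDetector,
)

# The module builds Collection objects directly, so a Milvus connection must
# already exist; host/port here are assumed defaults for a local deployment.
connections.connect(host="localhost", port="19530")

# Detect hardware and pick numpy/pandas (CPU) or cupy/cudf (GPU) accordingly.
loader = DynamicLibraryLoader(SystemDetector())

# Stand-in for the package's Hydra config; the database name is an assumption.
cfg = SimpleNamespace(
    milvus_db=SimpleNamespace(database_name="t2kg_primekg", query_batch_size=10000)
)

# Placeholder embeddings; real ones come from the same embedding models used
# to populate the node and edge collections.
text_emb = [0.0] * 768
query_emb = [0.0] * 768

# The "prompt" modality searches the shared nodes collection and skips the
# node-type specific collection.
pcst = MultimodalPCSTPruning(topk=5, topk_e=5, metric_type=loader.metric_type, loader=loader)
subgraph = pcst.extract_subgraph(text_emb, query_emb, modality="prompt", cfg=cfg)
print(subgraph["nodes"], subgraph["edges"])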