aiagents4pharma 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (336) hide show
  1. aiagents4pharma/__init__.py +11 -0
  2. aiagents4pharma/talk2aiagents4pharma/.dockerignore +13 -0
  3. aiagents4pharma/talk2aiagents4pharma/Dockerfile +133 -0
  4. aiagents4pharma/talk2aiagents4pharma/README.md +1 -0
  5. aiagents4pharma/talk2aiagents4pharma/__init__.py +5 -0
  6. aiagents4pharma/talk2aiagents4pharma/agents/__init__.py +6 -0
  7. aiagents4pharma/talk2aiagents4pharma/agents/main_agent.py +70 -0
  8. aiagents4pharma/talk2aiagents4pharma/configs/__init__.py +5 -0
  9. aiagents4pharma/talk2aiagents4pharma/configs/agents/__init__.py +5 -0
  10. aiagents4pharma/talk2aiagents4pharma/configs/agents/main_agent/default.yaml +29 -0
  11. aiagents4pharma/talk2aiagents4pharma/configs/app/__init__.py +0 -0
  12. aiagents4pharma/talk2aiagents4pharma/configs/app/frontend/__init__.py +0 -0
  13. aiagents4pharma/talk2aiagents4pharma/configs/app/frontend/default.yaml +102 -0
  14. aiagents4pharma/talk2aiagents4pharma/configs/config.yaml +4 -0
  15. aiagents4pharma/talk2aiagents4pharma/docker-compose/cpu/.env.example +23 -0
  16. aiagents4pharma/talk2aiagents4pharma/docker-compose/cpu/docker-compose.yml +93 -0
  17. aiagents4pharma/talk2aiagents4pharma/docker-compose/gpu/.env.example +23 -0
  18. aiagents4pharma/talk2aiagents4pharma/docker-compose/gpu/docker-compose.yml +108 -0
  19. aiagents4pharma/talk2aiagents4pharma/install.md +154 -0
  20. aiagents4pharma/talk2aiagents4pharma/states/__init__.py +5 -0
  21. aiagents4pharma/talk2aiagents4pharma/states/state_talk2aiagents4pharma.py +18 -0
  22. aiagents4pharma/talk2aiagents4pharma/tests/__init__.py +3 -0
  23. aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +312 -0
  24. aiagents4pharma/talk2biomodels/.dockerignore +13 -0
  25. aiagents4pharma/talk2biomodels/Dockerfile +104 -0
  26. aiagents4pharma/talk2biomodels/README.md +1 -0
  27. aiagents4pharma/talk2biomodels/__init__.py +5 -0
  28. aiagents4pharma/talk2biomodels/agents/__init__.py +6 -0
  29. aiagents4pharma/talk2biomodels/agents/t2b_agent.py +104 -0
  30. aiagents4pharma/talk2biomodels/api/__init__.py +5 -0
  31. aiagents4pharma/talk2biomodels/api/ols.py +75 -0
  32. aiagents4pharma/talk2biomodels/api/uniprot.py +36 -0
  33. aiagents4pharma/talk2biomodels/configs/__init__.py +5 -0
  34. aiagents4pharma/talk2biomodels/configs/agents/__init__.py +5 -0
  35. aiagents4pharma/talk2biomodels/configs/agents/t2b_agent/__init__.py +3 -0
  36. aiagents4pharma/talk2biomodels/configs/agents/t2b_agent/default.yaml +14 -0
  37. aiagents4pharma/talk2biomodels/configs/app/__init__.py +0 -0
  38. aiagents4pharma/talk2biomodels/configs/app/frontend/__init__.py +0 -0
  39. aiagents4pharma/talk2biomodels/configs/app/frontend/default.yaml +72 -0
  40. aiagents4pharma/talk2biomodels/configs/config.yaml +7 -0
  41. aiagents4pharma/talk2biomodels/configs/tools/__init__.py +5 -0
  42. aiagents4pharma/talk2biomodels/configs/tools/ask_question/__init__.py +3 -0
  43. aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +30 -0
  44. aiagents4pharma/talk2biomodels/configs/tools/custom_plotter/__init__.py +3 -0
  45. aiagents4pharma/talk2biomodels/configs/tools/custom_plotter/default.yaml +8 -0
  46. aiagents4pharma/talk2biomodels/configs/tools/get_annotation/__init__.py +3 -0
  47. aiagents4pharma/talk2biomodels/configs/tools/get_annotation/default.yaml +8 -0
  48. aiagents4pharma/talk2biomodels/install.md +63 -0
  49. aiagents4pharma/talk2biomodels/models/__init__.py +5 -0
  50. aiagents4pharma/talk2biomodels/models/basico_model.py +125 -0
  51. aiagents4pharma/talk2biomodels/models/sys_bio_model.py +60 -0
  52. aiagents4pharma/talk2biomodels/states/__init__.py +6 -0
  53. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +49 -0
  54. aiagents4pharma/talk2biomodels/tests/BIOMD0000000449_url.xml +1585 -0
  55. aiagents4pharma/talk2biomodels/tests/__init__.py +3 -0
  56. aiagents4pharma/talk2biomodels/tests/article_on_model_537.pdf +0 -0
  57. aiagents4pharma/talk2biomodels/tests/test_api.py +31 -0
  58. aiagents4pharma/talk2biomodels/tests/test_ask_question.py +42 -0
  59. aiagents4pharma/talk2biomodels/tests/test_basico_model.py +67 -0
  60. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +190 -0
  61. aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +92 -0
  62. aiagents4pharma/talk2biomodels/tests/test_integration.py +116 -0
  63. aiagents4pharma/talk2biomodels/tests/test_load_biomodel.py +35 -0
  64. aiagents4pharma/talk2biomodels/tests/test_param_scan.py +71 -0
  65. aiagents4pharma/talk2biomodels/tests/test_query_article.py +184 -0
  66. aiagents4pharma/talk2biomodels/tests/test_save_model.py +47 -0
  67. aiagents4pharma/talk2biomodels/tests/test_search_models.py +35 -0
  68. aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +44 -0
  69. aiagents4pharma/talk2biomodels/tests/test_steady_state.py +86 -0
  70. aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +67 -0
  71. aiagents4pharma/talk2biomodels/tools/__init__.py +17 -0
  72. aiagents4pharma/talk2biomodels/tools/ask_question.py +125 -0
  73. aiagents4pharma/talk2biomodels/tools/custom_plotter.py +165 -0
  74. aiagents4pharma/talk2biomodels/tools/get_annotation.py +342 -0
  75. aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +159 -0
  76. aiagents4pharma/talk2biomodels/tools/load_arguments.py +134 -0
  77. aiagents4pharma/talk2biomodels/tools/load_biomodel.py +44 -0
  78. aiagents4pharma/talk2biomodels/tools/parameter_scan.py +310 -0
  79. aiagents4pharma/talk2biomodels/tools/query_article.py +64 -0
  80. aiagents4pharma/talk2biomodels/tools/save_model.py +98 -0
  81. aiagents4pharma/talk2biomodels/tools/search_models.py +96 -0
  82. aiagents4pharma/talk2biomodels/tools/simulate_model.py +137 -0
  83. aiagents4pharma/talk2biomodels/tools/steady_state.py +187 -0
  84. aiagents4pharma/talk2biomodels/tools/utils.py +23 -0
  85. aiagents4pharma/talk2cells/README.md +1 -0
  86. aiagents4pharma/talk2cells/__init__.py +5 -0
  87. aiagents4pharma/talk2cells/agents/__init__.py +6 -0
  88. aiagents4pharma/talk2cells/agents/scp_agent.py +87 -0
  89. aiagents4pharma/talk2cells/states/__init__.py +6 -0
  90. aiagents4pharma/talk2cells/states/state_talk2cells.py +15 -0
  91. aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +22 -0
  92. aiagents4pharma/talk2cells/tools/__init__.py +6 -0
  93. aiagents4pharma/talk2cells/tools/scp_agent/__init__.py +6 -0
  94. aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py +27 -0
  95. aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py +78 -0
  96. aiagents4pharma/talk2knowledgegraphs/.dockerignore +13 -0
  97. aiagents4pharma/talk2knowledgegraphs/Dockerfile +131 -0
  98. aiagents4pharma/talk2knowledgegraphs/README.md +1 -0
  99. aiagents4pharma/talk2knowledgegraphs/__init__.py +5 -0
  100. aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +5 -0
  101. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +99 -0
  102. aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +5 -0
  103. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
  104. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
  105. aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +5 -0
  106. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
  107. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +79 -0
  108. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +13 -0
  109. aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +5 -0
  110. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
  111. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
  112. aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/__init__.py +0 -0
  113. aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +33 -0
  114. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
  115. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
  116. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
  117. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
  118. aiagents4pharma/talk2knowledgegraphs/configs/utils/database/milvus/__init__.py +3 -0
  119. aiagents4pharma/talk2knowledgegraphs/configs/utils/database/milvus/default.yaml +61 -0
  120. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/ols_terms/default.yaml +3 -0
  121. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/reactome_pathways/default.yaml +3 -0
  122. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/uniprot_proteins/default.yaml +6 -0
  123. aiagents4pharma/talk2knowledgegraphs/configs/utils/pubchem_utils/default.yaml +5 -0
  124. aiagents4pharma/talk2knowledgegraphs/datasets/__init__.py +5 -0
  125. aiagents4pharma/talk2knowledgegraphs/datasets/biobridge_primekg.py +607 -0
  126. aiagents4pharma/talk2knowledgegraphs/datasets/dataset.py +25 -0
  127. aiagents4pharma/talk2knowledgegraphs/datasets/primekg.py +212 -0
  128. aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py +210 -0
  129. aiagents4pharma/talk2knowledgegraphs/docker-compose/cpu/.env.example +23 -0
  130. aiagents4pharma/talk2knowledgegraphs/docker-compose/cpu/docker-compose.yml +93 -0
  131. aiagents4pharma/talk2knowledgegraphs/docker-compose/gpu/.env.example +23 -0
  132. aiagents4pharma/talk2knowledgegraphs/docker-compose/gpu/docker-compose.yml +108 -0
  133. aiagents4pharma/talk2knowledgegraphs/entrypoint.sh +180 -0
  134. aiagents4pharma/talk2knowledgegraphs/install.md +165 -0
  135. aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +886 -0
  136. aiagents4pharma/talk2knowledgegraphs/states/__init__.py +5 -0
  137. aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +40 -0
  138. aiagents4pharma/talk2knowledgegraphs/tests/__init__.py +0 -0
  139. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +318 -0
  140. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +248 -0
  141. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +33 -0
  142. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +86 -0
  143. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +125 -0
  144. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +257 -0
  145. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +1444 -0
  146. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_multimodal_subgraph_extraction.py +159 -0
  147. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +152 -0
  148. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +201 -0
  149. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_database_milvus_connection_manager.py +812 -0
  150. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +51 -0
  151. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +49 -0
  152. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_nim_molmim.py +59 -0
  153. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +63 -0
  154. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py +47 -0
  155. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py +40 -0
  156. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +94 -0
  157. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ols.py +70 -0
  158. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_pubchem.py +45 -0
  159. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_reactome.py +44 -0
  160. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_uniprot.py +48 -0
  161. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +759 -0
  162. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +78 -0
  163. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_pubchem_utils.py +123 -0
  164. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +11 -0
  165. aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +138 -0
  166. aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
  167. aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +965 -0
  168. aiagents4pharma/talk2knowledgegraphs/tools/multimodal_subgraph_extraction.py +374 -0
  169. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +291 -0
  170. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +123 -0
  171. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +5 -0
  172. aiagents4pharma/talk2knowledgegraphs/utils/database/__init__.py +5 -0
  173. aiagents4pharma/talk2knowledgegraphs/utils/database/milvus_connection_manager.py +586 -0
  174. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +5 -0
  175. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/embeddings.py +81 -0
  176. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/huggingface.py +111 -0
  177. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/nim_molmim.py +54 -0
  178. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +87 -0
  179. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/sentence_transformer.py +73 -0
  180. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +12 -0
  181. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py +37 -0
  182. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py +129 -0
  183. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ols_terms.py +89 -0
  184. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/pubchem_strings.py +78 -0
  185. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/reactome_pathways.py +71 -0
  186. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/uniprot_proteins.py +98 -0
  187. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +5 -0
  188. aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +762 -0
  189. aiagents4pharma/talk2knowledgegraphs/utils/extractions/multimodal_pcst.py +298 -0
  190. aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +229 -0
  191. aiagents4pharma/talk2knowledgegraphs/utils/kg_utils.py +67 -0
  192. aiagents4pharma/talk2knowledgegraphs/utils/pubchem_utils.py +104 -0
  193. aiagents4pharma/talk2scholars/.dockerignore +13 -0
  194. aiagents4pharma/talk2scholars/Dockerfile +104 -0
  195. aiagents4pharma/talk2scholars/README.md +1 -0
  196. aiagents4pharma/talk2scholars/__init__.py +7 -0
  197. aiagents4pharma/talk2scholars/agents/__init__.py +13 -0
  198. aiagents4pharma/talk2scholars/agents/main_agent.py +89 -0
  199. aiagents4pharma/talk2scholars/agents/paper_download_agent.py +96 -0
  200. aiagents4pharma/talk2scholars/agents/pdf_agent.py +101 -0
  201. aiagents4pharma/talk2scholars/agents/s2_agent.py +135 -0
  202. aiagents4pharma/talk2scholars/agents/zotero_agent.py +127 -0
  203. aiagents4pharma/talk2scholars/configs/__init__.py +7 -0
  204. aiagents4pharma/talk2scholars/configs/agents/__init__.py +7 -0
  205. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/__init__.py +7 -0
  206. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/__init__.py +3 -0
  207. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +52 -0
  208. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/__init__.py +3 -0
  209. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/default.yaml +19 -0
  210. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/__init__.py +3 -0
  211. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +19 -0
  212. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/__init__.py +3 -0
  213. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +44 -0
  214. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/__init__.py +3 -0
  215. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +19 -0
  216. aiagents4pharma/talk2scholars/configs/app/__init__.py +7 -0
  217. aiagents4pharma/talk2scholars/configs/app/frontend/__init__.py +3 -0
  218. aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +72 -0
  219. aiagents4pharma/talk2scholars/configs/config.yaml +16 -0
  220. aiagents4pharma/talk2scholars/configs/tools/__init__.py +21 -0
  221. aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/__init__.py +3 -0
  222. aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +26 -0
  223. aiagents4pharma/talk2scholars/configs/tools/paper_download/__init__.py +3 -0
  224. aiagents4pharma/talk2scholars/configs/tools/paper_download/default.yaml +124 -0
  225. aiagents4pharma/talk2scholars/configs/tools/question_and_answer/__init__.py +3 -0
  226. aiagents4pharma/talk2scholars/configs/tools/question_and_answer/default.yaml +62 -0
  227. aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/__init__.py +3 -0
  228. aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/default.yaml +12 -0
  229. aiagents4pharma/talk2scholars/configs/tools/search/__init__.py +3 -0
  230. aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +26 -0
  231. aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/__init__.py +3 -0
  232. aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +26 -0
  233. aiagents4pharma/talk2scholars/configs/tools/zotero_read/__init__.py +3 -0
  234. aiagents4pharma/talk2scholars/configs/tools/zotero_read/default.yaml +57 -0
  235. aiagents4pharma/talk2scholars/configs/tools/zotero_write/__inti__.py +3 -0
  236. aiagents4pharma/talk2scholars/configs/tools/zotero_write/default.yaml +55 -0
  237. aiagents4pharma/talk2scholars/docker-compose/cpu/.env.example +21 -0
  238. aiagents4pharma/talk2scholars/docker-compose/cpu/docker-compose.yml +90 -0
  239. aiagents4pharma/talk2scholars/docker-compose/gpu/.env.example +21 -0
  240. aiagents4pharma/talk2scholars/docker-compose/gpu/docker-compose.yml +105 -0
  241. aiagents4pharma/talk2scholars/install.md +122 -0
  242. aiagents4pharma/talk2scholars/state/__init__.py +7 -0
  243. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +98 -0
  244. aiagents4pharma/talk2scholars/tests/__init__.py +3 -0
  245. aiagents4pharma/talk2scholars/tests/test_agents_main_agent.py +256 -0
  246. aiagents4pharma/talk2scholars/tests/test_agents_paper_agents_download_agent.py +139 -0
  247. aiagents4pharma/talk2scholars/tests/test_agents_pdf_agent.py +114 -0
  248. aiagents4pharma/talk2scholars/tests/test_agents_s2_agent.py +198 -0
  249. aiagents4pharma/talk2scholars/tests/test_agents_zotero_agent.py +160 -0
  250. aiagents4pharma/talk2scholars/tests/test_s2_tools_display_dataframe.py +91 -0
  251. aiagents4pharma/talk2scholars/tests/test_s2_tools_query_dataframe.py +191 -0
  252. aiagents4pharma/talk2scholars/tests/test_states_state.py +38 -0
  253. aiagents4pharma/talk2scholars/tests/test_tools_paper_downloader.py +507 -0
  254. aiagents4pharma/talk2scholars/tests/test_tools_question_and_answer_tool.py +105 -0
  255. aiagents4pharma/talk2scholars/tests/test_tools_s2_multi.py +307 -0
  256. aiagents4pharma/talk2scholars/tests/test_tools_s2_retrieve.py +67 -0
  257. aiagents4pharma/talk2scholars/tests/test_tools_s2_search.py +286 -0
  258. aiagents4pharma/talk2scholars/tests/test_tools_s2_single.py +298 -0
  259. aiagents4pharma/talk2scholars/tests/test_utils_arxiv_downloader.py +469 -0
  260. aiagents4pharma/talk2scholars/tests/test_utils_base_paper_downloader.py +598 -0
  261. aiagents4pharma/talk2scholars/tests/test_utils_biorxiv_downloader.py +669 -0
  262. aiagents4pharma/talk2scholars/tests/test_utils_medrxiv_downloader.py +500 -0
  263. aiagents4pharma/talk2scholars/tests/test_utils_nvidia_nim_reranker.py +117 -0
  264. aiagents4pharma/talk2scholars/tests/test_utils_pdf_answer_formatter.py +67 -0
  265. aiagents4pharma/talk2scholars/tests/test_utils_pdf_batch_processor.py +92 -0
  266. aiagents4pharma/talk2scholars/tests/test_utils_pdf_collection_manager.py +173 -0
  267. aiagents4pharma/talk2scholars/tests/test_utils_pdf_document_processor.py +68 -0
  268. aiagents4pharma/talk2scholars/tests/test_utils_pdf_generate_answer.py +72 -0
  269. aiagents4pharma/talk2scholars/tests/test_utils_pdf_gpu_detection.py +129 -0
  270. aiagents4pharma/talk2scholars/tests/test_utils_pdf_paper_loader.py +116 -0
  271. aiagents4pharma/talk2scholars/tests/test_utils_pdf_rag_pipeline.py +88 -0
  272. aiagents4pharma/talk2scholars/tests/test_utils_pdf_retrieve_chunks.py +190 -0
  273. aiagents4pharma/talk2scholars/tests/test_utils_pdf_singleton_manager.py +159 -0
  274. aiagents4pharma/talk2scholars/tests/test_utils_pdf_vector_normalization.py +121 -0
  275. aiagents4pharma/talk2scholars/tests/test_utils_pdf_vector_store.py +406 -0
  276. aiagents4pharma/talk2scholars/tests/test_utils_pubmed_downloader.py +1007 -0
  277. aiagents4pharma/talk2scholars/tests/test_utils_read_helper_utils.py +106 -0
  278. aiagents4pharma/talk2scholars/tests/test_utils_s2_utils_ext_ids.py +403 -0
  279. aiagents4pharma/talk2scholars/tests/test_utils_tool_helper_utils.py +85 -0
  280. aiagents4pharma/talk2scholars/tests/test_utils_zotero_human_in_the_loop.py +266 -0
  281. aiagents4pharma/talk2scholars/tests/test_utils_zotero_path.py +496 -0
  282. aiagents4pharma/talk2scholars/tests/test_utils_zotero_pdf_downloader_utils.py +46 -0
  283. aiagents4pharma/talk2scholars/tests/test_utils_zotero_read.py +743 -0
  284. aiagents4pharma/talk2scholars/tests/test_utils_zotero_write.py +151 -0
  285. aiagents4pharma/talk2scholars/tools/__init__.py +9 -0
  286. aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +12 -0
  287. aiagents4pharma/talk2scholars/tools/paper_download/paper_downloader.py +442 -0
  288. aiagents4pharma/talk2scholars/tools/paper_download/utils/__init__.py +22 -0
  289. aiagents4pharma/talk2scholars/tools/paper_download/utils/arxiv_downloader.py +207 -0
  290. aiagents4pharma/talk2scholars/tools/paper_download/utils/base_paper_downloader.py +336 -0
  291. aiagents4pharma/talk2scholars/tools/paper_download/utils/biorxiv_downloader.py +313 -0
  292. aiagents4pharma/talk2scholars/tools/paper_download/utils/medrxiv_downloader.py +196 -0
  293. aiagents4pharma/talk2scholars/tools/paper_download/utils/pubmed_downloader.py +323 -0
  294. aiagents4pharma/talk2scholars/tools/pdf/__init__.py +7 -0
  295. aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +170 -0
  296. aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py +37 -0
  297. aiagents4pharma/talk2scholars/tools/pdf/utils/answer_formatter.py +62 -0
  298. aiagents4pharma/talk2scholars/tools/pdf/utils/batch_processor.py +198 -0
  299. aiagents4pharma/talk2scholars/tools/pdf/utils/collection_manager.py +172 -0
  300. aiagents4pharma/talk2scholars/tools/pdf/utils/document_processor.py +76 -0
  301. aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py +97 -0
  302. aiagents4pharma/talk2scholars/tools/pdf/utils/get_vectorstore.py +59 -0
  303. aiagents4pharma/talk2scholars/tools/pdf/utils/gpu_detection.py +150 -0
  304. aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py +97 -0
  305. aiagents4pharma/talk2scholars/tools/pdf/utils/paper_loader.py +123 -0
  306. aiagents4pharma/talk2scholars/tools/pdf/utils/rag_pipeline.py +113 -0
  307. aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py +197 -0
  308. aiagents4pharma/talk2scholars/tools/pdf/utils/singleton_manager.py +140 -0
  309. aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py +86 -0
  310. aiagents4pharma/talk2scholars/tools/pdf/utils/vector_normalization.py +150 -0
  311. aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py +327 -0
  312. aiagents4pharma/talk2scholars/tools/s2/__init__.py +21 -0
  313. aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py +110 -0
  314. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +111 -0
  315. aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +233 -0
  316. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +128 -0
  317. aiagents4pharma/talk2scholars/tools/s2/search.py +101 -0
  318. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +102 -0
  319. aiagents4pharma/talk2scholars/tools/s2/utils/__init__.py +5 -0
  320. aiagents4pharma/talk2scholars/tools/s2/utils/multi_helper.py +223 -0
  321. aiagents4pharma/talk2scholars/tools/s2/utils/search_helper.py +205 -0
  322. aiagents4pharma/talk2scholars/tools/s2/utils/single_helper.py +216 -0
  323. aiagents4pharma/talk2scholars/tools/zotero/__init__.py +7 -0
  324. aiagents4pharma/talk2scholars/tools/zotero/utils/__init__.py +7 -0
  325. aiagents4pharma/talk2scholars/tools/zotero/utils/read_helper.py +270 -0
  326. aiagents4pharma/talk2scholars/tools/zotero/utils/review_helper.py +74 -0
  327. aiagents4pharma/talk2scholars/tools/zotero/utils/write_helper.py +194 -0
  328. aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_path.py +180 -0
  329. aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_pdf_downloader.py +133 -0
  330. aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +105 -0
  331. aiagents4pharma/talk2scholars/tools/zotero/zotero_review.py +162 -0
  332. aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +91 -0
  333. aiagents4pharma-0.0.0.dist-info/METADATA +335 -0
  334. aiagents4pharma-0.0.0.dist-info/RECORD +336 -0
  335. aiagents4pharma-0.0.0.dist-info/WHEEL +4 -0
  336. aiagents4pharma-0.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,298 @@
1
+ """
2
+ Extraction of multimodal subgraph using Prize-Collecting Steiner Tree (PCST) algorithm.
3
+ """
4
+
5
+ from typing import NamedTuple
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import pcst_fast
10
+ import torch
11
+ from torch_geometric.data.data import Data
12
+
13
+
14
+ class MultimodalPCSTPruning(NamedTuple):
15
+ """
16
+ Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
17
+ (He et al., 'G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and
18
+ Question Answering', NeurIPS 2024) paper.
19
+ https://arxiv.org/abs/2402.07630
20
+ https://github.com/XiaoxinHe/G-Retriever/blob/main/src/dataset/utils/retrieval.py
21
+
22
+ Args:
23
+ topk: The number of top nodes to consider.
24
+ topk_e: The number of top edges to consider.
25
+ cost_e: The cost of the edges.
26
+ c_const: The constant value for the cost of the edges computation.
27
+ root: The root node of the subgraph, -1 for unrooted.
28
+ num_clusters: The number of clusters.
29
+ pruning: The pruning strategy to use.
30
+ verbosity_level: The verbosity level.
31
+ """
32
+
33
+ topk: int = 3
34
+ topk_e: int = 3
35
+ cost_e: float = 0.5
36
+ c_const: float = 0.01
37
+ root: int = -1
38
+ num_clusters: int = 1
39
+ pruning: str = "gw"
40
+ verbosity_level: int = 0
41
+ use_description: bool = False
42
+
43
+ def _compute_node_prizes(self, graph: Data, query_emb: torch.Tensor, modality: str):
44
+ """
45
+ Compute the node prizes based on the cosine similarity between the query and nodes.
46
+
47
+ Args:
48
+ graph: The knowledge graph in PyTorch Geometric Data format.
49
+ query_emb: The query embedding in PyTorch Tensor format. This can be an embedding of
50
+ a prompt, sequence, or any other feature to be used for the subgraph extraction.
51
+ modality: The modality to use for the subgraph extraction based on the node type.
52
+
53
+ Returns:
54
+ The prizes of the nodes.
55
+ """
56
+ # Convert PyG graph to a DataFrame
57
+ graph_df = pd.DataFrame(
58
+ {
59
+ "node_type": graph.node_type,
60
+ "desc_x": [x.tolist() for x in graph.desc_x],
61
+ "x": [list(x) for x in graph.x],
62
+ "score": [0.0 for _ in range(len(graph.node_id))],
63
+ }
64
+ )
65
+
66
+ # Calculate cosine similarity for text features and update the score
67
+ if self.use_description:
68
+ graph_df.loc[:, "score"] = torch.nn.CosineSimilarity(dim=-1)(
69
+ query_emb,
70
+ torch.tensor(list(graph_df.desc_x.values)), # Using textual description features
71
+ ).tolist()
72
+ else:
73
+ graph_df.loc[graph_df["node_type"] == modality, "score"] = torch.nn.CosineSimilarity(
74
+ dim=-1
75
+ )(
76
+ query_emb,
77
+ torch.tensor(list(graph_df[graph_df["node_type"] == modality].x.values)),
78
+ ).tolist()
79
+
80
+ # Set the prizes for nodes based on the similarity scores
81
+ n_prizes = torch.tensor(graph_df.score.values, dtype=torch.float32)
82
+ # n_prizes = torch.nn.CosineSimilarity(dim=-1)(query_emb, graph.x)
83
+ topk = min(self.topk, graph.num_nodes)
84
+ _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
85
+ n_prizes = torch.zeros_like(n_prizes)
86
+ n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
87
+
88
+ return n_prizes
89
+
90
+ def _compute_edge_prizes(self, graph: Data, text_emb: torch.Tensor):
91
+ """
92
+ Compute the node prizes based on the cosine similarity between the query and nodes.
93
+
94
+ Args:
95
+ graph: The knowledge graph in PyTorch Geometric Data format.
96
+ text_emb: The textual description embedding in PyTorch Tensor format.
97
+
98
+ Returns:
99
+ The prizes of the nodes.
100
+ """
101
+ # Note that as of now, the edge features are based on textual features
102
+ # Compute prizes for edges
103
+ e_prizes = torch.nn.CosineSimilarity(dim=-1)(text_emb, graph.edge_attr)
104
+ unique_prizes, inverse_indices = e_prizes.unique(return_inverse=True)
105
+ topk_e = min(self.topk_e, unique_prizes.size(0))
106
+ topk_e_values, _ = torch.topk(unique_prizes, topk_e, largest=True)
107
+ e_prizes[e_prizes < topk_e_values[-1]] = 0.0
108
+ last_topk_e_value = topk_e
109
+ for k in range(topk_e):
110
+ indices = (
111
+ inverse_indices == (unique_prizes == topk_e_values[k]).nonzero(as_tuple=True)[0]
112
+ )
113
+ value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
114
+ e_prizes[indices] = value
115
+ last_topk_e_value = value * (1 - self.c_const)
116
+
117
+ return e_prizes
118
+
119
def compute_prizes(
    self,
    graph: Data,
    text_emb: torch.Tensor,
    query_emb: torch.Tensor,
    modality: str,
):
    """
    Score the nodes and the edges of the graph ahead of the subgraph extraction.

    Node scoring is delegated to `_compute_node_prizes`, which receives the query
    embedding and the modality; edge scoring is delegated to `_compute_edge_prizes`,
    which receives the text embedding.

    Args:
        graph: The knowledge graph in PyTorch Geometric Data format.
        text_emb: The textual description embedding in PyTorch Tensor format.
        query_emb: The query embedding in PyTorch Tensor format. This can be an embedding of
            a prompt, sequence, or any other feature to be used for the subgraph extraction.
        modality: The modality to use for the subgraph extraction based on node type.

    Returns:
        A dictionary holding the node prizes under "nodes" and the edge prizes under "edges".
    """
    # Dictionary entries are evaluated top to bottom: nodes are scored before edges.
    return {
        "nodes": self._compute_node_prizes(graph, query_emb, modality),
        "edges": self._compute_edge_prizes(graph, text_emb),
    }
149
+
150
def compute_subgraph_costs(
    self, graph: Data, prizes: dict
) -> tuple[dict, np.ndarray, np.ndarray, dict]:
    """
    Build the edges, prizes and costs of the PCST problem as proposed by the G-Retriever paper.

    An edge whose prize does not exceed the edge cost stays a regular edge with cost
    `cost - prize`. Any other edge is replaced by a virtual node carrying the surplus
    `prize - cost`, connected to both endpoints of the edge by two zero-cost edges.

    Args:
        graph: The knowledge graph in PyTorch Geometric Data format.
        prizes: The prizes of the nodes (key "nodes") and of the edges (key "edges").

    Returns:
        edges_dict: A dictionary with "edges" (array of the regular edges followed by the
            virtual edges) and "num_prior_edges" (the number of regular edges).
        prizes: The node prizes followed by the prizes of the virtual nodes.
        costs: The costs aligned with the rows of `edges_dict["edges"]`.
        mapping: "edges" maps the position of a regular edge to its index in the graph;
            "nodes" maps a virtual node id to the index of the edge it stands for.
    """
    # Cap the edge cost slightly below the best edge prize so that the best-scoring edge
    # keeps a positive net prize (whenever that prize is positive)
    updated_cost_e = min(
        self.cost_e,
        prizes["edges"].max().item() * (1 - self.c_const / 2),
    )

    # Initialize variables
    edges = []
    costs = []
    virtual = {"n_prizes": [], "edges": [], "costs": []}
    mapping = {"nodes": {}, "edges": {}}

    # Compute the costs, edges, and virtual variables based on the prizes
    for i, (src, dst) in enumerate(graph.edge_index.T.numpy()):
        prize_e = prizes["edges"][i]
        if prize_e <= updated_cost_e:
            mapping["edges"][len(edges)] = i
            edges.append((src, dst))
            costs.append(updated_cost_e - prize_e)
        else:
            # Virtual node ids continue right after the ids of the real nodes
            virtual_node_id = graph.num_nodes + len(virtual["n_prizes"])
            mapping["nodes"][virtual_node_id] = i
            virtual["edges"].extend([(src, virtual_node_id), (virtual_node_id, dst)])
            virtual["costs"].extend([0, 0])
            virtual["n_prizes"].append(prize_e - updated_cost_e)

    prizes = np.concatenate([prizes["nodes"], np.array(virtual["n_prizes"])])

    # Always hand back arrays, also when no virtual node was created, so that the
    # returned types do not depend on the prize values
    edges_dict = {
        "edges": np.array(edges + virtual["edges"]),
        "num_prior_edges": len(edges),
    }
    costs = np.array(costs + virtual["costs"])

    return edges_dict, prizes, costs, mapping
208
+
209
def get_subgraph_nodes_edges(
    self,
    graph: Data,
    vertices: np.ndarray,
    edges_dict: dict,
    mapping: dict,
) -> dict:
    """
    Get the selected nodes and edges of the subgraph based on the vertices and edges computed
    by the PCST algorithm.

    Args:
        graph: The knowledge graph in PyTorch Geometric Data format.
        vertices: The vertices of the subgraph computed by the PCST algorithm.
        edges_dict: A dictionary with the edges selected by the PCST algorithm ("edges")
            and the number of prior edges, i.e. edges that are not virtual
            ("num_prior_edges").
        mapping: The mapping dictionary of the nodes and edges: "edges" translates a prior
            edge position to an edge index of the graph, "nodes" translates a virtual node
            id to the edge index of the graph it stands for.

    Returns:
        A dictionary with the node ids ("nodes") and the edge indices ("edges") of the
        extracted subgraph, both as NumPy arrays.
    """
    num_prior_edges = edges_dict["num_prior_edges"]

    # Selected positions below num_prior_edges refer to regular edges; the remaining ones
    # are virtual edges, recovered through the selected virtual nodes instead
    selected_edges = [mapping["edges"][e] for e in edges_dict["edges"] if e < num_prior_edges]

    # Vertex ids at or beyond graph.num_nodes are virtual nodes, each standing for an edge
    is_virtual = vertices >= graph.num_nodes
    selected_edges.extend(mapping["nodes"][v] for v in vertices[is_virtual])

    # Always return an integer array, even when no virtual node (or no edge) was selected
    subgraph_edges = np.array(selected_edges, dtype=np.int64)

    # The endpoints of every selected edge belong to the subgraph as well
    edge_index = graph.edge_index[:, torch.from_numpy(subgraph_edges)]
    subgraph_nodes = np.unique(
        np.concatenate([vertices[~is_virtual], edge_index[0].numpy(), edge_index[1].numpy()])
    )

    return {"nodes": subgraph_nodes, "edges": subgraph_edges}
248
+
249
def extract_subgraph(
    self,
    graph: Data,
    text_emb: torch.Tensor,
    query_emb: torch.Tensor,
    modality: str,
) -> dict:
    """
    Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.

    Args:
        graph: The knowledge graph in PyTorch Geometric Data format.
        text_emb: The textual description embedding in PyTorch Tensor format.
        query_emb: The query embedding in PyTorch Tensor format. This can be an embedding of
            a prompt, sequence, or any other feature to be used for the subgraph extraction.
        modality: The modality to use for the subgraph extraction
            (e.g., "text", "sequence", "smiles").

    Returns:
        The selected nodes and edges of the subgraph.

    Raises:
        AssertionError: If `topk` or `topk_e` is not strictly positive.
    """
    # Both values must be strictly positive; the messages state exactly what is checked
    assert self.topk > 0, "topk must be greater than 0"
    assert self.topk_e > 0, "topk_e must be greater than 0"

    # Retrieve the top-k nodes and edges based on the query embedding
    prizes = self.compute_prizes(graph, text_emb, query_emb, modality)

    # Compute costs in constructing the subgraph
    edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(graph, prizes)

    # Retrieve the subgraph using the PCST algorithm
    result_vertices, result_edges = pcst_fast.pcst_fast(
        edges_dict["edges"],
        prizes,
        costs,
        self.root,
        self.num_clusters,
        self.pruning,
        self.verbosity_level,
    )

    # Translate the PCST result (which may contain virtual nodes and edges) back to the
    # node ids and edge indices of the original graph
    subgraph = self.get_subgraph_nodes_edges(
        graph,
        result_vertices,
        {"edges": result_edges, "num_prior_edges": edges_dict["num_prior_edges"]},
        mapping,
    )

    return subgraph
@@ -0,0 +1,229 @@
1
+ """
2
+ Extraction of subgraph using Prize-Collecting Steiner Tree (PCST) algorithm.
3
+ """
4
+
5
+ from typing import NamedTuple
6
+
7
+ import numpy as np
8
+ import pcst_fast
9
+ import torch
10
+ from torch_geometric.data.data import Data
11
+
12
+
13
class PCSTPruning(NamedTuple):
    """
    Prize-Collecting Steiner Tree (PCST) pruning algorithm implementation inspired by G-Retriever
    (He et al., 'G-Retriever: Retrieval-Augmented Generation for Textual Graph Understanding and
    Question Answering', NeurIPS 2024) paper.
    https://arxiv.org/abs/2402.07630
    https://github.com/XiaoxinHe/G-Retriever/blob/main/src/dataset/utils/retrieval.py

    Args:
        topk: The number of top nodes to consider.
        topk_e: The number of top edges to consider.
        cost_e: The cost of the edges.
        c_const: The constant value for the cost of the edges computation.
        root: The root node of the subgraph, -1 for unrooted.
        num_clusters: The number of clusters.
        pruning: The pruning strategy to use.
        verbosity_level: The verbosity level.
    """

    topk: int = 3
    topk_e: int = 3
    cost_e: float = 0.5
    c_const: float = 0.01
    root: int = -1
    num_clusters: int = 1
    pruning: str = "gw"
    verbosity_level: int = 0

    def compute_prizes(self, graph: Data, query_emb: torch.Tensor) -> dict:
        """
        Compute the node prizes based on the cosine similarity between the query and nodes,
        as well as the edge prizes based on the cosine similarity between the query and edges.
        Note that the node and edge embeddings shall use the same embedding model and dimensions
        with the query.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            query_emb: The query embedding in PyTorch Tensor format.

        Returns:
            A dictionary holding the node prizes under "nodes" and the edge prizes under
            "edges".
        """
        # Compute prizes for nodes: the top-k most similar nodes receive the prizes
        # topk, topk - 1, ..., 1 (best first); every other node receives 0
        n_prizes = torch.nn.CosineSimilarity(dim=-1)(query_emb, graph.x)
        topk = min(self.topk, graph.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()

        # Compute prizes for edges: ranks are taken over the distinct similarity values,
        # and edges below the topk_e-th distinct value receive 0
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(query_emb, graph.edge_attr)
        unique_prizes, inverse_indices = e_prizes.unique(return_inverse=True)
        topk_e = min(self.topk_e, unique_prizes.size(0))
        topk_e_values, _ = torch.topk(unique_prizes, topk_e, largest=True)
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            # Mask of the edges whose similarity equals the k-th largest distinct value
            indices = (
                inverse_indices == (unique_prizes == topk_e_values[k]).nonzero(as_tuple=True)[0]
            )
            # Edges tied on the same similarity share the rank weight (topk_e - k) equally,
            # capped by the previous tier's prize scaled by (1 - c_const)
            value = min((topk_e - k) / indices.sum().item(), last_topk_e_value)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - self.c_const)

        return {"nodes": n_prizes, "edges": e_prizes}

    def compute_subgraph_costs(
        self, graph: Data, prizes: dict
    ) -> tuple[dict, np.ndarray, np.ndarray, dict]:
        """
        Compute the costs in constructing the subgraph proposed by G-Retriever paper.

        An edge whose prize does not exceed the edge cost stays a regular edge with cost
        `cost - prize`. Any other edge is replaced by a virtual node carrying the surplus
        `prize - cost`, connected to both endpoints of the edge by two zero-cost edges.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            prizes: The prizes of the nodes (key "nodes") and of the edges (key "edges").

        Returns:
            edges_dict: A dictionary with "edges" (array of the regular edges followed by the
                virtual edges) and "num_prior_edges" (the number of regular edges).
            prizes: The node prizes followed by the prizes of the virtual nodes.
            costs: The costs aligned with the rows of `edges_dict["edges"]`.
            mapping: "edges" maps the position of a regular edge to its index in the graph;
                "nodes" maps a virtual node id to the index of the edge it stands for.
        """
        # Cap the edge cost slightly below the best edge prize so that the best-scoring edge
        # keeps a positive net prize (whenever that prize is positive)
        updated_cost_e = min(
            self.cost_e,
            prizes["edges"].max().item() * (1 - self.c_const / 2),
        )

        # Initialize variables
        edges = []
        costs = []
        virtual = {"n_prizes": [], "edges": [], "costs": []}
        mapping = {"nodes": {}, "edges": {}}

        # Compute the costs, edges, and virtual variables based on the prizes
        for i, (src, dst) in enumerate(graph.edge_index.T.numpy()):
            prize_e = prizes["edges"][i]
            if prize_e <= updated_cost_e:
                mapping["edges"][len(edges)] = i
                edges.append((src, dst))
                costs.append(updated_cost_e - prize_e)
            else:
                # Virtual node ids continue right after the ids of the real nodes
                virtual_node_id = graph.num_nodes + len(virtual["n_prizes"])
                mapping["nodes"][virtual_node_id] = i
                virtual["edges"].extend([(src, virtual_node_id), (virtual_node_id, dst)])
                virtual["costs"].extend([0, 0])
                virtual["n_prizes"].append(prize_e - updated_cost_e)

        prizes = np.concatenate([prizes["nodes"], np.array(virtual["n_prizes"])])

        # Always hand back arrays, also when no virtual node was created, so that the
        # returned types do not depend on the prize values
        edges_dict = {
            "edges": np.array(edges + virtual["edges"]),
            "num_prior_edges": len(edges),
        }
        costs = np.array(costs + virtual["costs"])

        return edges_dict, prizes, costs, mapping

    def get_subgraph_nodes_edges(
        self,
        graph: Data,
        vertices: np.ndarray,
        edges_dict: dict,
        mapping: dict,
    ) -> dict:
        """
        Get the selected nodes and edges of the subgraph based on the vertices and edges computed
        by the PCST algorithm.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            vertices: The vertices of the subgraph computed by the PCST algorithm.
            edges_dict: A dictionary with the edges selected by the PCST algorithm ("edges")
                and the number of prior edges, i.e. edges that are not virtual
                ("num_prior_edges").
            mapping: The mapping dictionary of the nodes and edges: "edges" translates a prior
                edge position to an edge index of the graph, "nodes" translates a virtual node
                id to the edge index of the graph it stands for.

        Returns:
            A dictionary with the node ids ("nodes") and the edge indices ("edges") of the
            extracted subgraph, both as NumPy arrays.
        """
        num_prior_edges = edges_dict["num_prior_edges"]

        # Selected positions below num_prior_edges refer to regular edges; the remaining ones
        # are virtual edges, recovered through the selected virtual nodes instead
        selected_edges = [
            mapping["edges"][e] for e in edges_dict["edges"] if e < num_prior_edges
        ]

        # Vertex ids at or beyond graph.num_nodes are virtual nodes, each standing for an edge
        is_virtual = vertices >= graph.num_nodes
        selected_edges.extend(mapping["nodes"][v] for v in vertices[is_virtual])

        # Always return an integer array, even when no virtual node (or no edge) was selected
        subgraph_edges = np.array(selected_edges, dtype=np.int64)

        # The endpoints of every selected edge belong to the subgraph as well
        edge_index = graph.edge_index[:, torch.from_numpy(subgraph_edges)]
        subgraph_nodes = np.unique(
            np.concatenate([vertices[~is_virtual], edge_index[0].numpy(), edge_index[1].numpy()])
        )

        return {"nodes": subgraph_nodes, "edges": subgraph_edges}

    def extract_subgraph(self, graph: Data, query_emb: torch.Tensor) -> dict:
        """
        Perform the Prize-Collecting Steiner Tree (PCST) algorithm to extract the subgraph.

        Args:
            graph: The knowledge graph in PyTorch Geometric Data format.
            query_emb: The query embedding.

        Returns:
            The selected nodes and edges of the subgraph.

        Raises:
            AssertionError: If `topk` or `topk_e` is not strictly positive.
        """
        # Both values must be strictly positive; the messages state exactly what is checked
        assert self.topk > 0, "topk must be greater than 0"
        assert self.topk_e > 0, "topk_e must be greater than 0"

        # Retrieve the top-k nodes and edges based on the query embedding
        prizes = self.compute_prizes(graph, query_emb)

        # Compute costs in constructing the subgraph
        edges_dict, prizes, costs, mapping = self.compute_subgraph_costs(graph, prizes)

        # Retrieve the subgraph using the PCST algorithm
        result_vertices, result_edges = pcst_fast.pcst_fast(
            edges_dict["edges"],
            prizes,
            costs,
            self.root,
            self.num_clusters,
            self.pruning,
            self.verbosity_level,
        )

        # Translate the PCST result (which may contain virtual nodes and edges) back to the
        # node ids and edge indices of the original graph
        subgraph = self.get_subgraph_nodes_edges(
            graph,
            result_vertices,
            {"edges": result_edges, "num_prior_edges": edges_dict["num_prior_edges"]},
            mapping,
        )

        return subgraph
@@ -0,0 +1,67 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """A utility module for knowledge graph operations"""
4
+
5
+ import networkx as nx
6
+ import pandas as pd
7
+
8
+
9
def kg_to_df_pandas(kg: nx.DiGraph) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Split a directed knowledge graph into a node table and an edge table.

    Args:
        kg: The directed knowledge graph in networkX format.

    Returns:
        df_nodes: A pandas DataFrame built from `kg.nodes`, indexed by node.
        df_edges: A pandas DataFrame holding the edge list of the graph, whose endpoint
            columns are named "node_source" and "node_target".
    """
    # The node table is built first, then the edge table
    return (
        pd.DataFrame.from_dict(kg.nodes, orient="index"),
        nx.to_pandas_edgelist(kg, source="node_source", target="node_target"),
    )
28
+
29
+
30
def df_pandas_to_kg(
    df: pd.DataFrame, df_nodes_attrs: pd.DataFrame, node_source: str, node_target: str
) -> nx.DiGraph:
    """
    Build a directed knowledge graph out of an edge table and a node-attribute table.

    Args:
        df: A pandas DataFrame of the edges in the knowledge graph.
        df_nodes_attrs: A pandas DataFrame of the nodes in the knowledge graph, indexed by
            node; its columns become the node attributes.
        node_source: The column name of the source node in the df.
        node_target: The column name of the target node in the df.

    Returns:
        kg: The directed knowledge graph in networkX format.

    Raises:
        AssertionError: If an endpoint column is missing from `df`, or if `df_nodes_attrs`
            is indexed by a node that appears in neither endpoint column.
    """
    # Both endpoint columns must exist in the edge table (source is checked first)
    for endpoint_column in (node_source, node_target):
        assert endpoint_column in df.columns, f"{endpoint_column} not in df"

    # Every node carrying attributes must be an endpoint of at least one edge
    assert set(df_nodes_attrs.index) <= set(df[node_source]) | set(df[node_target]), (
        "Nodes in index of df_nodes not found in df_edges"
    )

    # Edges come first (all remaining columns of df become edge attributes),
    # then the node attributes are attached to the nodes
    graph = nx.from_pandas_edgelist(
        df,
        source=node_source,
        target=node_target,
        edge_attr=True,
        create_using=nx.DiGraph,
    )
    node_records = df_nodes_attrs.to_dict("index")
    graph.add_nodes_from(node_records.items())

    return graph