aiagents4pharma 1.43.0__py3-none-any.whl → 1.45.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (290) hide show
  1. aiagents4pharma/__init__.py +2 -2
  2. aiagents4pharma/talk2aiagents4pharma/.dockerignore +13 -0
  3. aiagents4pharma/talk2aiagents4pharma/Dockerfile +105 -0
  4. aiagents4pharma/talk2aiagents4pharma/README.md +1 -0
  5. aiagents4pharma/talk2aiagents4pharma/__init__.py +4 -5
  6. aiagents4pharma/talk2aiagents4pharma/agents/__init__.py +3 -2
  7. aiagents4pharma/talk2aiagents4pharma/agents/main_agent.py +24 -23
  8. aiagents4pharma/talk2aiagents4pharma/configs/__init__.py +2 -2
  9. aiagents4pharma/talk2aiagents4pharma/configs/agents/__init__.py +2 -2
  10. aiagents4pharma/talk2aiagents4pharma/configs/agents/main_agent/default.yaml +2 -2
  11. aiagents4pharma/talk2aiagents4pharma/configs/config.yaml +1 -1
  12. aiagents4pharma/talk2aiagents4pharma/docker-compose/cpu/.env.example +23 -0
  13. aiagents4pharma/talk2aiagents4pharma/docker-compose/cpu/docker-compose.yml +93 -0
  14. aiagents4pharma/talk2aiagents4pharma/docker-compose/gpu/.env.example +23 -0
  15. aiagents4pharma/talk2aiagents4pharma/docker-compose/gpu/docker-compose.yml +108 -0
  16. aiagents4pharma/talk2aiagents4pharma/install.md +127 -0
  17. aiagents4pharma/talk2aiagents4pharma/states/__init__.py +3 -2
  18. aiagents4pharma/talk2aiagents4pharma/states/state_talk2aiagents4pharma.py +5 -3
  19. aiagents4pharma/talk2aiagents4pharma/tests/__init__.py +2 -2
  20. aiagents4pharma/talk2aiagents4pharma/tests/test_main_agent.py +72 -50
  21. aiagents4pharma/talk2biomodels/.dockerignore +13 -0
  22. aiagents4pharma/talk2biomodels/Dockerfile +104 -0
  23. aiagents4pharma/talk2biomodels/README.md +1 -0
  24. aiagents4pharma/talk2biomodels/__init__.py +4 -8
  25. aiagents4pharma/talk2biomodels/agents/__init__.py +3 -2
  26. aiagents4pharma/talk2biomodels/agents/t2b_agent.py +47 -42
  27. aiagents4pharma/talk2biomodels/api/__init__.py +4 -5
  28. aiagents4pharma/talk2biomodels/api/kegg.py +14 -10
  29. aiagents4pharma/talk2biomodels/api/ols.py +13 -10
  30. aiagents4pharma/talk2biomodels/api/uniprot.py +7 -6
  31. aiagents4pharma/talk2biomodels/configs/__init__.py +3 -4
  32. aiagents4pharma/talk2biomodels/configs/agents/__init__.py +2 -2
  33. aiagents4pharma/talk2biomodels/configs/agents/t2b_agent/__init__.py +2 -2
  34. aiagents4pharma/talk2biomodels/configs/agents/t2b_agent/default.yaml +1 -1
  35. aiagents4pharma/talk2biomodels/configs/config.yaml +1 -1
  36. aiagents4pharma/talk2biomodels/configs/tools/__init__.py +4 -5
  37. aiagents4pharma/talk2biomodels/configs/tools/ask_question/__init__.py +2 -2
  38. aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +1 -2
  39. aiagents4pharma/talk2biomodels/configs/tools/custom_plotter/__init__.py +2 -2
  40. aiagents4pharma/talk2biomodels/configs/tools/custom_plotter/default.yaml +1 -1
  41. aiagents4pharma/talk2biomodels/configs/tools/get_annotation/__init__.py +2 -2
  42. aiagents4pharma/talk2biomodels/configs/tools/get_annotation/default.yaml +1 -1
  43. aiagents4pharma/talk2biomodels/install.md +63 -0
  44. aiagents4pharma/talk2biomodels/models/__init__.py +4 -4
  45. aiagents4pharma/talk2biomodels/models/basico_model.py +36 -28
  46. aiagents4pharma/talk2biomodels/models/sys_bio_model.py +13 -10
  47. aiagents4pharma/talk2biomodels/states/__init__.py +3 -2
  48. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +12 -8
  49. aiagents4pharma/talk2biomodels/tests/BIOMD0000000449_url.xml +1585 -0
  50. aiagents4pharma/talk2biomodels/tests/__init__.py +2 -2
  51. aiagents4pharma/talk2biomodels/tests/article_on_model_537.pdf +0 -0
  52. aiagents4pharma/talk2biomodels/tests/test_api.py +18 -14
  53. aiagents4pharma/talk2biomodels/tests/test_ask_question.py +8 -9
  54. aiagents4pharma/talk2biomodels/tests/test_basico_model.py +15 -9
  55. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +54 -55
  56. aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +28 -27
  57. aiagents4pharma/talk2biomodels/tests/test_integration.py +21 -33
  58. aiagents4pharma/talk2biomodels/tests/test_load_biomodel.py +14 -11
  59. aiagents4pharma/talk2biomodels/tests/test_param_scan.py +21 -20
  60. aiagents4pharma/talk2biomodels/tests/test_query_article.py +129 -29
  61. aiagents4pharma/talk2biomodels/tests/test_search_models.py +9 -13
  62. aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +16 -15
  63. aiagents4pharma/talk2biomodels/tests/test_steady_state.py +12 -22
  64. aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +33 -29
  65. aiagents4pharma/talk2biomodels/tools/__init__.py +15 -12
  66. aiagents4pharma/talk2biomodels/tools/ask_question.py +42 -32
  67. aiagents4pharma/talk2biomodels/tools/custom_plotter.py +51 -43
  68. aiagents4pharma/talk2biomodels/tools/get_annotation.py +99 -75
  69. aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +57 -51
  70. aiagents4pharma/talk2biomodels/tools/load_arguments.py +52 -32
  71. aiagents4pharma/talk2biomodels/tools/load_biomodel.py +8 -2
  72. aiagents4pharma/talk2biomodels/tools/parameter_scan.py +107 -90
  73. aiagents4pharma/talk2biomodels/tools/query_article.py +14 -13
  74. aiagents4pharma/talk2biomodels/tools/search_models.py +37 -26
  75. aiagents4pharma/talk2biomodels/tools/simulate_model.py +47 -37
  76. aiagents4pharma/talk2biomodels/tools/steady_state.py +76 -58
  77. aiagents4pharma/talk2biomodels/tools/utils.py +4 -3
  78. aiagents4pharma/talk2cells/README.md +1 -0
  79. aiagents4pharma/talk2cells/__init__.py +4 -5
  80. aiagents4pharma/talk2cells/agents/__init__.py +3 -2
  81. aiagents4pharma/talk2cells/agents/scp_agent.py +21 -19
  82. aiagents4pharma/talk2cells/states/__init__.py +3 -2
  83. aiagents4pharma/talk2cells/states/state_talk2cells.py +4 -2
  84. aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +8 -9
  85. aiagents4pharma/talk2cells/tools/__init__.py +3 -2
  86. aiagents4pharma/talk2cells/tools/scp_agent/__init__.py +4 -4
  87. aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py +5 -3
  88. aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py +21 -22
  89. aiagents4pharma/talk2knowledgegraphs/.dockerignore +13 -0
  90. aiagents4pharma/talk2knowledgegraphs/Dockerfile +103 -0
  91. aiagents4pharma/talk2knowledgegraphs/README.md +1 -0
  92. aiagents4pharma/talk2knowledgegraphs/__init__.py +4 -7
  93. aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +3 -2
  94. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +40 -30
  95. aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +3 -6
  96. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +2 -2
  97. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +8 -8
  98. aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +3 -2
  99. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +2 -2
  100. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +1 -1
  101. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +1 -1
  102. aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +4 -5
  103. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +2 -2
  104. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +1 -1
  105. aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +17 -2
  106. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +2 -2
  107. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +1 -1
  108. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +2 -2
  109. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +1 -1
  110. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/ols_terms/default.yaml +1 -1
  111. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/reactome_pathways/default.yaml +1 -1
  112. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/uniprot_proteins/default.yaml +1 -1
  113. aiagents4pharma/talk2knowledgegraphs/configs/utils/pubchem_utils/default.yaml +1 -1
  114. aiagents4pharma/talk2knowledgegraphs/datasets/__init__.py +4 -6
  115. aiagents4pharma/talk2knowledgegraphs/datasets/biobridge_primekg.py +115 -67
  116. aiagents4pharma/talk2knowledgegraphs/datasets/dataset.py +2 -0
  117. aiagents4pharma/talk2knowledgegraphs/datasets/primekg.py +35 -24
  118. aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py +29 -21
  119. aiagents4pharma/talk2knowledgegraphs/docker-compose/cpu/.env.example +23 -0
  120. aiagents4pharma/talk2knowledgegraphs/docker-compose/cpu/docker-compose.yml +93 -0
  121. aiagents4pharma/talk2knowledgegraphs/docker-compose/gpu/.env.example +23 -0
  122. aiagents4pharma/talk2knowledgegraphs/docker-compose/gpu/docker-compose.yml +108 -0
  123. aiagents4pharma/talk2knowledgegraphs/entrypoint.sh +190 -0
  124. aiagents4pharma/talk2knowledgegraphs/install.md +140 -0
  125. aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +31 -65
  126. aiagents4pharma/talk2knowledgegraphs/states/__init__.py +3 -2
  127. aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +1 -0
  128. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +65 -40
  129. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +54 -48
  130. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +4 -0
  131. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +17 -4
  132. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +33 -24
  133. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +116 -69
  134. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_milvus_multimodal_subgraph_extraction.py +736 -413
  135. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_multimodal_subgraph_extraction.py +22 -15
  136. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +19 -12
  137. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +95 -48
  138. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +4 -0
  139. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +5 -0
  140. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_nim_molmim.py +13 -18
  141. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +10 -3
  142. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py +4 -3
  143. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +3 -2
  144. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ols.py +1 -0
  145. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_pubchem.py +9 -4
  146. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_reactome.py +6 -6
  147. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_uniprot.py +4 -0
  148. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_extractions_milvus_multimodal_pcst.py +442 -42
  149. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +3 -4
  150. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_pubchem_utils.py +10 -6
  151. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +10 -7
  152. aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +15 -20
  153. aiagents4pharma/talk2knowledgegraphs/tools/milvus_multimodal_subgraph_extraction.py +245 -205
  154. aiagents4pharma/talk2knowledgegraphs/tools/multimodal_subgraph_extraction.py +92 -90
  155. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +25 -37
  156. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +10 -13
  157. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -7
  158. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +4 -7
  159. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/embeddings.py +4 -0
  160. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/huggingface.py +11 -14
  161. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/nim_molmim.py +7 -7
  162. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +12 -6
  163. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/sentence_transformer.py +8 -6
  164. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +9 -6
  165. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py +1 -0
  166. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py +15 -9
  167. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ols_terms.py +23 -20
  168. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/pubchem_strings.py +12 -10
  169. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/reactome_pathways.py +16 -10
  170. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/uniprot_proteins.py +26 -18
  171. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -5
  172. aiagents4pharma/talk2knowledgegraphs/utils/extractions/milvus_multimodal_pcst.py +218 -81
  173. aiagents4pharma/talk2knowledgegraphs/utils/extractions/multimodal_pcst.py +53 -47
  174. aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +18 -14
  175. aiagents4pharma/talk2knowledgegraphs/utils/kg_utils.py +22 -23
  176. aiagents4pharma/talk2knowledgegraphs/utils/pubchem_utils.py +11 -10
  177. aiagents4pharma/talk2scholars/.dockerignore +13 -0
  178. aiagents4pharma/talk2scholars/Dockerfile +104 -0
  179. aiagents4pharma/talk2scholars/README.md +1 -0
  180. aiagents4pharma/talk2scholars/agents/__init__.py +1 -5
  181. aiagents4pharma/talk2scholars/agents/main_agent.py +6 -4
  182. aiagents4pharma/talk2scholars/agents/paper_download_agent.py +5 -4
  183. aiagents4pharma/talk2scholars/agents/pdf_agent.py +4 -2
  184. aiagents4pharma/talk2scholars/agents/s2_agent.py +2 -2
  185. aiagents4pharma/talk2scholars/agents/zotero_agent.py +10 -11
  186. aiagents4pharma/talk2scholars/configs/__init__.py +1 -3
  187. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/__init__.py +1 -4
  188. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +1 -1
  189. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +1 -1
  190. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +8 -8
  191. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +7 -7
  192. aiagents4pharma/talk2scholars/configs/tools/__init__.py +8 -6
  193. aiagents4pharma/talk2scholars/docker-compose/cpu/.env.example +21 -0
  194. aiagents4pharma/talk2scholars/docker-compose/cpu/docker-compose.yml +90 -0
  195. aiagents4pharma/talk2scholars/docker-compose/gpu/.env.example +21 -0
  196. aiagents4pharma/talk2scholars/docker-compose/gpu/docker-compose.yml +105 -0
  197. aiagents4pharma/talk2scholars/install.md +122 -0
  198. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +8 -8
  199. aiagents4pharma/talk2scholars/tests/{test_main_agent.py → test_agents_main_agent.py} +41 -23
  200. aiagents4pharma/talk2scholars/tests/{test_paper_download_agent.py → test_agents_paper_agents_download_agent.py} +10 -16
  201. aiagents4pharma/talk2scholars/tests/{test_pdf_agent.py → test_agents_pdf_agent.py} +6 -10
  202. aiagents4pharma/talk2scholars/tests/{test_s2_agent.py → test_agents_s2_agent.py} +8 -16
  203. aiagents4pharma/talk2scholars/tests/{test_zotero_agent.py → test_agents_zotero_agent.py} +5 -7
  204. aiagents4pharma/talk2scholars/tests/{test_s2_display_dataframe.py → test_s2_tools_display_dataframe.py} +6 -7
  205. aiagents4pharma/talk2scholars/tests/{test_s2_query_dataframe.py → test_s2_tools_query_dataframe.py} +5 -15
  206. aiagents4pharma/talk2scholars/tests/{test_paper_downloader.py → test_tools_paper_downloader.py} +25 -63
  207. aiagents4pharma/talk2scholars/tests/{test_question_and_answer_tool.py → test_tools_question_and_answer_tool.py} +2 -6
  208. aiagents4pharma/talk2scholars/tests/{test_s2_multi.py → test_tools_s2_multi.py} +5 -5
  209. aiagents4pharma/talk2scholars/tests/{test_s2_retrieve.py → test_tools_s2_retrieve.py} +2 -1
  210. aiagents4pharma/talk2scholars/tests/{test_s2_search.py → test_tools_s2_search.py} +5 -5
  211. aiagents4pharma/talk2scholars/tests/{test_s2_single.py → test_tools_s2_single.py} +5 -5
  212. aiagents4pharma/talk2scholars/tests/{test_arxiv_downloader.py → test_utils_arxiv_downloader.py} +16 -25
  213. aiagents4pharma/talk2scholars/tests/{test_base_paper_downloader.py → test_utils_base_paper_downloader.py} +25 -47
  214. aiagents4pharma/talk2scholars/tests/{test_biorxiv_downloader.py → test_utils_biorxiv_downloader.py} +14 -42
  215. aiagents4pharma/talk2scholars/tests/{test_medrxiv_downloader.py → test_utils_medrxiv_downloader.py} +15 -49
  216. aiagents4pharma/talk2scholars/tests/{test_nvidia_nim_reranker.py → test_utils_nvidia_nim_reranker.py} +6 -16
  217. aiagents4pharma/talk2scholars/tests/{test_pdf_answer_formatter.py → test_utils_pdf_answer_formatter.py} +1 -0
  218. aiagents4pharma/talk2scholars/tests/{test_pdf_batch_processor.py → test_utils_pdf_batch_processor.py} +6 -15
  219. aiagents4pharma/talk2scholars/tests/{test_pdf_collection_manager.py → test_utils_pdf_collection_manager.py} +34 -11
  220. aiagents4pharma/talk2scholars/tests/{test_pdf_document_processor.py → test_utils_pdf_document_processor.py} +2 -3
  221. aiagents4pharma/talk2scholars/tests/{test_pdf_generate_answer.py → test_utils_pdf_generate_answer.py} +3 -6
  222. aiagents4pharma/talk2scholars/tests/{test_pdf_gpu_detection.py → test_utils_pdf_gpu_detection.py} +5 -16
  223. aiagents4pharma/talk2scholars/tests/{test_pdf_rag_pipeline.py → test_utils_pdf_rag_pipeline.py} +7 -17
  224. aiagents4pharma/talk2scholars/tests/{test_pdf_retrieve_chunks.py → test_utils_pdf_retrieve_chunks.py} +4 -11
  225. aiagents4pharma/talk2scholars/tests/{test_pdf_singleton_manager.py → test_utils_pdf_singleton_manager.py} +26 -23
  226. aiagents4pharma/talk2scholars/tests/{test_pdf_vector_normalization.py → test_utils_pdf_vector_normalization.py} +1 -1
  227. aiagents4pharma/talk2scholars/tests/{test_pdf_vector_store.py → test_utils_pdf_vector_store.py} +27 -55
  228. aiagents4pharma/talk2scholars/tests/{test_pubmed_downloader.py → test_utils_pubmed_downloader.py} +31 -91
  229. aiagents4pharma/talk2scholars/tests/{test_read_helper_utils.py → test_utils_read_helper_utils.py} +2 -6
  230. aiagents4pharma/talk2scholars/tests/{test_s2_utils_ext_ids.py → test_utils_s2_utils_ext_ids.py} +5 -15
  231. aiagents4pharma/talk2scholars/tests/{test_zotero_human_in_the_loop.py → test_utils_zotero_human_in_the_loop.py} +6 -13
  232. aiagents4pharma/talk2scholars/tests/{test_zotero_path.py → test_utils_zotero_path.py} +53 -45
  233. aiagents4pharma/talk2scholars/tests/{test_zotero_read.py → test_utils_zotero_read.py} +30 -91
  234. aiagents4pharma/talk2scholars/tests/{test_zotero_write.py → test_utils_zotero_write.py} +6 -16
  235. aiagents4pharma/talk2scholars/tools/__init__.py +1 -4
  236. aiagents4pharma/talk2scholars/tools/paper_download/paper_downloader.py +20 -35
  237. aiagents4pharma/talk2scholars/tools/paper_download/utils/__init__.py +7 -5
  238. aiagents4pharma/talk2scholars/tools/paper_download/utils/arxiv_downloader.py +9 -11
  239. aiagents4pharma/talk2scholars/tools/paper_download/utils/base_paper_downloader.py +14 -21
  240. aiagents4pharma/talk2scholars/tools/paper_download/utils/biorxiv_downloader.py +14 -22
  241. aiagents4pharma/talk2scholars/tools/paper_download/utils/medrxiv_downloader.py +11 -13
  242. aiagents4pharma/talk2scholars/tools/paper_download/utils/pubmed_downloader.py +14 -28
  243. aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +4 -8
  244. aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py +16 -14
  245. aiagents4pharma/talk2scholars/tools/pdf/utils/answer_formatter.py +4 -4
  246. aiagents4pharma/talk2scholars/tools/pdf/utils/batch_processor.py +15 -17
  247. aiagents4pharma/talk2scholars/tools/pdf/utils/collection_manager.py +2 -2
  248. aiagents4pharma/talk2scholars/tools/pdf/utils/document_processor.py +5 -5
  249. aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py +4 -4
  250. aiagents4pharma/talk2scholars/tools/pdf/utils/get_vectorstore.py +2 -6
  251. aiagents4pharma/talk2scholars/tools/pdf/utils/gpu_detection.py +5 -9
  252. aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py +4 -4
  253. aiagents4pharma/talk2scholars/tools/pdf/utils/paper_loader.py +2 -2
  254. aiagents4pharma/talk2scholars/tools/pdf/utils/rag_pipeline.py +6 -15
  255. aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py +7 -15
  256. aiagents4pharma/talk2scholars/tools/pdf/utils/singleton_manager.py +2 -2
  257. aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py +3 -4
  258. aiagents4pharma/talk2scholars/tools/pdf/utils/vector_normalization.py +8 -17
  259. aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py +17 -33
  260. aiagents4pharma/talk2scholars/tools/s2/__init__.py +8 -6
  261. aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py +3 -7
  262. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +7 -6
  263. aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +5 -12
  264. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +2 -4
  265. aiagents4pharma/talk2scholars/tools/s2/search.py +6 -6
  266. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +5 -3
  267. aiagents4pharma/talk2scholars/tools/s2/utils/__init__.py +1 -3
  268. aiagents4pharma/talk2scholars/tools/s2/utils/multi_helper.py +12 -18
  269. aiagents4pharma/talk2scholars/tools/s2/utils/search_helper.py +11 -18
  270. aiagents4pharma/talk2scholars/tools/s2/utils/single_helper.py +11 -16
  271. aiagents4pharma/talk2scholars/tools/zotero/__init__.py +1 -4
  272. aiagents4pharma/talk2scholars/tools/zotero/utils/__init__.py +1 -4
  273. aiagents4pharma/talk2scholars/tools/zotero/utils/read_helper.py +21 -39
  274. aiagents4pharma/talk2scholars/tools/zotero/utils/review_helper.py +2 -6
  275. aiagents4pharma/talk2scholars/tools/zotero/utils/write_helper.py +8 -11
  276. aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_path.py +4 -12
  277. aiagents4pharma/talk2scholars/tools/zotero/utils/zotero_pdf_downloader.py +13 -27
  278. aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +4 -7
  279. aiagents4pharma/talk2scholars/tools/zotero/zotero_review.py +8 -10
  280. aiagents4pharma/talk2scholars/tools/zotero/zotero_write.py +3 -2
  281. {aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.45.0.dist-info}/METADATA +115 -50
  282. aiagents4pharma-1.45.0.dist-info/RECORD +324 -0
  283. {aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.45.0.dist-info}/WHEEL +1 -2
  284. aiagents4pharma-1.43.0.dist-info/RECORD +0 -293
  285. aiagents4pharma-1.43.0.dist-info/top_level.txt +0 -1
  286. /aiagents4pharma/talk2scholars/tests/{test_state.py → test_states_state.py} +0 -0
  287. /aiagents4pharma/talk2scholars/tests/{test_pdf_paper_loader.py → test_utils_pdf_paper_loader.py} +0 -0
  288. /aiagents4pharma/talk2scholars/tests/{test_tool_helper_utils.py → test_utils_tool_helper_utils.py} +0 -0
  289. /aiagents4pharma/talk2scholars/tests/{test_zotero_pdf_downloader_utils.py → test_utils_zotero_pdf_downloader_utils.py} +0 -0
  290. {aiagents4pharma-1.43.0.dist-info → aiagents4pharma-1.45.0.dist-info}/licenses/LICENSE +0 -0
@@ -2,27 +2,25 @@
2
2
  Tool for performing multimodal subgraph extraction.
3
3
  """
4
4
 
5
- # import datetime
6
- from typing import Type, Annotated
7
5
  import logging
6
+ from typing import Annotated
7
+
8
8
  import hydra
9
9
  import pandas as pd
10
- from pydantic import BaseModel, Field
11
- from langchain_core.tools import BaseTool
12
10
  from langchain_core.messages import ToolMessage
11
+ from langchain_core.tools import BaseTool
13
12
  from langchain_core.tools.base import InjectedToolCallId
14
- from langgraph.types import Command
15
13
  from langgraph.prebuilt import InjectedState
14
+ from langgraph.types import Command
15
+ from pydantic import BaseModel, Field
16
16
  from pymilvus import Collection
17
- from ..utils.extractions.milvus_multimodal_pcst import MultimodalPCSTPruning
17
+
18
+ from ..utils.extractions.milvus_multimodal_pcst import (
19
+ DynamicLibraryLoader,
20
+ MultimodalPCSTPruning,
21
+ SystemDetector,
22
+ )
18
23
  from .load_arguments import ArgumentData
19
- try:
20
- import cupy as py
21
- import cudf
22
- df = cudf
23
- except ImportError:
24
- import numpy as py
25
- df = pd
26
24
 
27
25
  # Initialize logger
28
26
  logging.basicConfig(level=logging.INFO)
@@ -41,14 +39,10 @@ class MultimodalSubgraphExtractionInput(BaseModel):
41
39
  arg_data: Argument for analytical process over graph data.
42
40
  """
43
41
 
44
- tool_call_id: Annotated[str, InjectedToolCallId] = Field(
45
- description="Tool call ID."
46
- )
42
+ tool_call_id: Annotated[str, InjectedToolCallId] = Field(description="Tool call ID.")
47
43
  state: Annotated[dict, InjectedState] = Field(description="Injected state.")
48
44
  prompt: str = Field(description="Prompt to interact with the backend.")
49
- arg_data: ArgumentData = Field(
50
- description="Experiment over graph data.", default=None
51
- )
45
+ arg_data: ArgumentData = Field(description="Experiment over graph data.", default=None)
52
46
 
53
47
 
54
48
  class MultimodalSubgraphExtractionTool(BaseTool):
@@ -59,10 +53,19 @@ class MultimodalSubgraphExtractionTool(BaseTool):
59
53
 
60
54
  name: str = "subgraph_extraction"
61
55
  description: str = "A tool for subgraph extraction based on user's prompt."
62
- args_schema: Type[BaseModel] = MultimodalSubgraphExtractionInput
56
+ args_schema: type[BaseModel] = MultimodalSubgraphExtractionInput
57
+
58
+ def __init__(self, **kwargs):
59
+ super().__init__(**kwargs)
60
+ # Initialize hardware detection and dynamic library loading
61
+ object.__setattr__(self, "detector", SystemDetector())
62
+ object.__setattr__(self, "loader", DynamicLibraryLoader(self.detector))
63
+ logger.info(
64
+ "MultimodalSubgraphExtractionTool initialized with %s mode",
65
+ "GPU" if self.loader.use_gpu else "CPU",
66
+ )
63
67
 
64
- def _read_multimodal_files(self,
65
- state: Annotated[dict, InjectedState]) -> df.DataFrame:
68
+ def _read_multimodal_files(self, state: Annotated[dict, InjectedState]):
66
69
  """
67
70
  Read the uploaded multimodal files and return a DataFrame.
68
71
 
@@ -72,7 +75,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
72
75
  Returns:
73
76
  A DataFrame containing the multimodal files.
74
77
  """
75
- multimodal_df = df.DataFrame({"name": [], "node_type": []})
78
+ multimodal_df = self.loader.df.DataFrame({"name": [], "node_type": []})
76
79
 
77
80
  # Loop over the uploaded files and find multimodal files
78
81
  logger.log(logging.INFO, "Looping over uploaded files")
@@ -80,8 +83,9 @@ class MultimodalSubgraphExtractionTool(BaseTool):
80
83
  # Check if multimodal file is uploaded
81
84
  if state["uploaded_files"][i]["file_type"] == "multimodal":
82
85
  # Read the Excel file
83
- multimodal_df = pd.read_excel(state["uploaded_files"][i]["file_path"],
84
- sheet_name=None)
86
+ multimodal_df = pd.read_excel(
87
+ state["uploaded_files"][i]["file_path"], sheet_name=None
88
+ )
85
89
 
86
90
  # Check if the multimodal_df is empty
87
91
  logger.log(logging.INFO, "Checking if multimodal_df is empty")
@@ -90,20 +94,48 @@ class MultimodalSubgraphExtractionTool(BaseTool):
90
94
  logger.log(logging.INFO, "Preparing multimodal_df")
91
95
  # Merge all obtained dataframes into a single dataframe
92
96
  multimodal_df = pd.concat(multimodal_df).reset_index()
93
- multimodal_df = df.DataFrame(multimodal_df)
97
+ multimodal_df = self.loader.df.DataFrame(multimodal_df)
94
98
  multimodal_df.drop(columns=["level_1"], inplace=True)
95
- multimodal_df.rename(columns={"level_0": "q_node_type",
96
- "name": "q_node_name"}, inplace=True)
99
+ multimodal_df.rename(
100
+ columns={"level_0": "q_node_type", "name": "q_node_name"}, inplace=True
101
+ )
97
102
  # Since an excel sheet name could not contain a `/`,
98
103
  # but the node type can be 'gene/protein' as exists in the PrimeKG
99
- multimodal_df["q_node_type"] = multimodal_df["q_node_type"].str.replace('-', '_')
104
+ multimodal_df["q_node_type"] = multimodal_df["q_node_type"].str.replace("-", "_")
100
105
 
101
106
  return multimodal_df
102
107
 
103
- def _prepare_query_modalities(self,
104
- prompt: dict,
105
- state: Annotated[dict, InjectedState],
106
- cfg_db: dict) -> df.DataFrame:
108
+ def _query_milvus_collection(self, node_type, node_type_df, cfg_db):
109
+ """Helper method to query Milvus collection for a specific node type."""
110
+ # Load the collection
111
+ collection = Collection(
112
+ name=f"{cfg_db.milvus_db.database_name}_nodes_{node_type.replace('/', '_')}"
113
+ )
114
+ collection.load()
115
+
116
+ # Query the collection with node names from multimodal_df
117
+ node_names_series = node_type_df["q_node_name"]
118
+ q_node_names = getattr(
119
+ node_names_series, "to_pandas", lambda series=node_names_series: series
120
+ )().tolist()
121
+ q_columns = ["node_id", "node_name", "node_type", "feat", "feat_emb", "desc", "desc_emb"]
122
+ res = collection.query(
123
+ expr=f"node_name IN [{','.join(f'"{name}"' for name in q_node_names)}]",
124
+ output_fields=q_columns,
125
+ )
126
+ # Convert the embeedings into floats
127
+ for r_ in res:
128
+ r_["feat_emb"] = [float(x) for x in r_["feat_emb"]]
129
+ r_["desc_emb"] = [float(x) for x in r_["desc_emb"]]
130
+
131
+ # Convert the result to a DataFrame
132
+ res_df = self.loader.df.DataFrame(res)[q_columns]
133
+ res_df["use_description"] = False
134
+ return res_df
135
+
136
+ def _prepare_query_modalities(
137
+ self, prompt: dict, state: Annotated[dict, InjectedState], cfg_db: dict
138
+ ):
107
139
  """
108
140
  Prepare the modality-specific query for subgraph extraction.
109
141
 
@@ -118,16 +150,18 @@ class MultimodalSubgraphExtractionTool(BaseTool):
118
150
  # Initialize dataframes
119
151
  logger.log(logging.INFO, "Initializing dataframes")
120
152
  query_df = []
121
- prompt_df = df.DataFrame({
122
- 'node_id': 'user_prompt',
123
- 'node_name': 'User Prompt',
124
- 'node_type': 'prompt',
125
- 'feat': prompt["text"],
126
- 'feat_emb': prompt["emb"],
127
- 'desc': prompt["text"],
128
- 'desc_emb': prompt["emb"],
129
- 'use_description': True # set to True for user prompt embedding
130
- })
153
+ prompt_df = self.loader.df.DataFrame(
154
+ {
155
+ "node_id": "user_prompt",
156
+ "node_name": "User Prompt",
157
+ "node_type": "prompt",
158
+ "feat": prompt["text"],
159
+ "feat_emb": prompt["emb"],
160
+ "desc": prompt["text"],
161
+ "desc_emb": prompt["emb"],
162
+ "use_description": True, # set to True for user prompt embedding
163
+ }
164
+ )
131
165
 
132
166
  # Read multimodal files uploaded by the user
133
167
  multimodal_df = self._read_multimodal_files(state)
@@ -136,64 +170,44 @@ class MultimodalSubgraphExtractionTool(BaseTool):
136
170
  logger.log(logging.INFO, "Prepare query modalities")
137
171
  if len(multimodal_df) > 0:
138
172
  # Query the Milvus database for each node type in multimodal_df
139
- logger.log(logging.INFO, "Querying Milvus database for each node type in multimodal_df")
173
+ logger.log(
174
+ logging.INFO,
175
+ "Querying Milvus database for each node type in multimodal_df",
176
+ )
140
177
  for node_type, node_type_df in multimodal_df.groupby("q_node_type"):
141
178
  print(f"Processing node type: {node_type}")
142
-
143
- # Load the collection
144
- collection = Collection(
145
- name=f"{cfg_db.milvus_db.database_name}_nodes_{node_type.replace('/', '_')}"
146
- )
147
- collection.load()
148
-
149
- # Query the collection with node names from multimodal_df
150
- q_node_names = getattr(node_type_df['q_node_name'],
151
- "to_pandas",
152
- lambda: node_type_df['q_node_name'])().tolist()
153
- q_columns = ["node_id", "node_name", "node_type",
154
- "feat", "feat_emb", "desc", "desc_emb"]
155
- res = collection.query(
156
- expr=f'node_name IN [{','.join(f'"{name}"' for name in q_node_names)}]',
157
- output_fields=q_columns,
158
- )
159
- # Convert the embeedings into floats
160
- for r_ in res:
161
- r_['feat_emb'] = [float(x) for x in r_['feat_emb']]
162
- r_['desc_emb'] = [float(x) for x in r_['desc_emb']]
163
-
164
- # Convert the result to a DataFrame
165
- res_df = df.DataFrame(res)[q_columns]
166
- res_df["use_description"] = False
167
-
168
- # Append the results to query_df
179
+ res_df = self._query_milvus_collection(node_type, node_type_df, cfg_db)
169
180
  query_df.append(res_df)
170
181
 
171
182
  # Concatenate all results into a single DataFrame
172
183
  logger.log(logging.INFO, "Concatenating all results into a single DataFrame")
173
- query_df = df.concat(query_df, ignore_index=True)
184
+ query_df = self.loader.df.concat(query_df, ignore_index=True)
174
185
 
175
186
  # Update the state by adding the selected node IDs
176
187
  logger.log(logging.INFO, "Updating state with selected node IDs")
177
- state["selections"] = getattr(query_df,
178
- "to_pandas",
179
- lambda: query_df)().groupby(
180
- "node_type"
181
- )["node_id"].apply(list).to_dict()
188
+ state["selections"] = (
189
+ getattr(query_df, "to_pandas", lambda: query_df)()
190
+ .groupby("node_type")["node_id"]
191
+ .apply(list)
192
+ .to_dict()
193
+ )
182
194
 
183
195
  # Append a user prompt to the query dataframe
184
196
  logger.log(logging.INFO, "Adding user prompt to query dataframe")
185
- query_df = df.concat([query_df, prompt_df]).reset_index(drop=True)
197
+ query_df = self.loader.df.concat([query_df, prompt_df]).reset_index(drop=True)
186
198
  else:
187
199
  # If no multimodal files are uploaded, use the prompt embeddings
188
200
  query_df = prompt_df
189
201
 
190
202
  return query_df
191
203
 
192
- def _perform_subgraph_extraction(self,
193
- state: Annotated[dict, InjectedState],
194
- cfg: dict,
195
- cfg_db: dict,
196
- query_df: pd.DataFrame) -> dict:
204
+ def _perform_subgraph_extraction(
205
+ self,
206
+ state: Annotated[dict, InjectedState],
207
+ cfg: dict,
208
+ cfg_db: dict,
209
+ query_df: pd.DataFrame,
210
+ ) -> dict:
197
211
  """
198
212
  Perform multimodal subgraph extraction based on modal-specific embeddings.
199
213
 
@@ -208,10 +222,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
208
222
  """
209
223
  # Initialize the subgraph dictionary
210
224
  subgraphs = []
211
- unified_subgraph = {
212
- "nodes": [],
213
- "edges": []
214
- }
225
+ unified_subgraph = {"nodes": [], "edges": []}
215
226
  # subgraphs = {}
216
227
  # subgraphs["nodes"] = []
217
228
  # subgraphs["edges"] = []
@@ -219,10 +230,22 @@ class MultimodalSubgraphExtractionTool(BaseTool):
219
230
  # Loop over query embeddings and modalities
220
231
  for q in getattr(query_df, "to_pandas", lambda: query_df)().iterrows():
221
232
  logger.log(logging.INFO, "===========================================")
222
- logger.log(logging.INFO, "Processing query: %s", q[1]['node_name'])
233
+ logger.log(logging.INFO, "Processing query: %s", q[1]["node_name"])
223
234
  # Prepare the PCSTPruning object and extract the subgraph
224
235
  # Parameters were set in the configuration file obtained from Hydra
225
236
  # start = datetime.datetime.now()
237
+ # Get dynamic metric type (overrides any config setting)
238
+ # Fall back to cfg.search_metric_type when dynamic metrics are not enabled
239
+ has_vector_processing = hasattr(cfg, "vector_processing")
240
+ if has_vector_processing:
241
+ dynamic_metrics_enabled = getattr(cfg.vector_processing, "dynamic_metrics", True)
242
+ else:
243
+ dynamic_metrics_enabled = False
244
+ if has_vector_processing and dynamic_metrics_enabled:
245
+ dynamic_metric_type = self.loader.metric_type
246
+ else:
247
+ dynamic_metric_type = getattr(cfg, "search_metric_type", self.loader.metric_type)
248
+
226
249
  subgraph = MultimodalPCSTPruning(
227
250
  topk=state["topk_nodes"],
228
251
  topk_e=state["topk_edges"],
@@ -232,49 +255,51 @@ class MultimodalSubgraphExtractionTool(BaseTool):
232
255
  num_clusters=cfg.num_clusters,
233
256
  pruning=cfg.pruning,
234
257
  verbosity_level=cfg.verbosity_level,
235
- use_description=q[1]['use_description'],
236
- metric_type=cfg.search_metric_type
237
- ).extract_subgraph(q[1]['desc_emb'],
238
- q[1]['feat_emb'],
239
- q[1]['node_type'],
240
- cfg_db)
258
+ use_description=q[1]["use_description"],
259
+ metric_type=dynamic_metric_type, # Use dynamic or config metric type
260
+ loader=self.loader, # Pass the loader instance
261
+ ).extract_subgraph(q[1]["desc_emb"], q[1]["feat_emb"], q[1]["node_type"], cfg_db)
241
262
 
242
263
  # Append the extracted subgraph to the dictionary
243
264
  unified_subgraph["nodes"].append(subgraph["nodes"].tolist())
244
265
  unified_subgraph["edges"].append(subgraph["edges"].tolist())
245
- subgraphs.append((q[1]['node_name'],
246
- subgraph["nodes"].tolist(),
247
- subgraph["edges"].tolist()))
266
+ subgraphs.append(
267
+ (
268
+ q[1]["node_name"],
269
+ subgraph["nodes"].tolist(),
270
+ subgraph["edges"].tolist(),
271
+ )
272
+ )
248
273
 
249
274
  # end = datetime.datetime.now()
250
275
  # logger.log(logging.INFO, "Subgraph extraction time: %s seconds",
251
276
  # (end - start).total_seconds())
252
277
 
253
278
  # Concatenate and get unique node and edge indices
254
- unified_subgraph["nodes"] = py.unique(
255
- py.concatenate([py.array(list_) for list_ in unified_subgraph["nodes"]])
279
+ nodes_arrays = [self.loader.py.array(list_) for list_ in unified_subgraph["nodes"]]
280
+ unified_subgraph["nodes"] = self.loader.py.unique(
281
+ self.loader.py.concatenate(nodes_arrays)
256
282
  ).tolist()
257
- unified_subgraph["edges"] = py.unique(
258
- py.concatenate([py.array(list_) for list_ in unified_subgraph["edges"]])
283
+ edges_arrays = [self.loader.py.array(list_) for list_ in unified_subgraph["edges"]]
284
+ unified_subgraph["edges"] = self.loader.py.unique(
285
+ self.loader.py.concatenate(edges_arrays)
259
286
  ).tolist()
260
287
 
261
- # Convert the unified subgraph and subgraphs to cudf DataFrames
262
- unified_subgraph = df.DataFrame([("Unified Subgraph",
263
- unified_subgraph["nodes"],
264
- unified_subgraph["edges"])],
265
- columns=["name", "nodes", "edges"])
266
- subgraphs = df.DataFrame(subgraphs, columns=["name", "nodes", "edges"])
288
+ # Convert the unified subgraph and subgraphs to DataFrames
289
+ unified_subgraph = self.loader.df.DataFrame(
290
+ [("Unified Subgraph", unified_subgraph["nodes"], unified_subgraph["edges"])],
291
+ columns=["name", "nodes", "edges"],
292
+ )
293
+ subgraphs = self.loader.df.DataFrame(subgraphs, columns=["name", "nodes", "edges"])
267
294
 
268
- # Concate both DataFrames
269
- subgraphs = df.concat([unified_subgraph, subgraphs], ignore_index=True)
295
+ # Concatenate both DataFrames
296
+ subgraphs = self.loader.df.concat([unified_subgraph, subgraphs], ignore_index=True)
270
297
 
271
298
  return subgraphs
272
299
 
273
- def _prepare_final_subgraph(self,
274
- state:Annotated[dict, InjectedState],
275
- subgraph: dict,
276
- cfg: dict,
277
- cfg_db) -> dict:
300
+ def _prepare_final_subgraph(
301
+ self, state: Annotated[dict, InjectedState], subgraph: dict, cfg: dict, cfg_db
302
+ ) -> dict:
278
303
  """
279
304
  Prepare the subgraph based on the extracted subgraph.
280
305
 
@@ -288,94 +313,110 @@ class MultimodalSubgraphExtractionTool(BaseTool):
288
313
  Returns:
289
314
  A dictionary containing the PyG graph, NetworkX graph, and textualized graph.
290
315
  """
291
- # Convert the dict to a cudf DataFrame
292
- node_colors = {n: cfg.node_colors_dict[k]
293
- for k, v in state["selections"].items() for n in v}
294
- color_df = df.DataFrame(list(node_colors.items()), columns=["node_id", "color"])
316
+ # Convert the dict to a DataFrame
317
+ node_colors = {
318
+ n: cfg.node_colors_dict[k] for k, v in state["selections"].items() for n in v
319
+ }
320
+ color_df = self.loader.df.DataFrame(list(node_colors.items()), columns=["node_id", "color"])
295
321
  # print(color_df)
296
322
 
297
323
  # Prepare the subgraph dictionary
298
- graph_dict = {
299
- "name": [],
300
- "nodes": [],
301
- "edges": [],
302
- "text": ""
303
- }
324
+ graph_dict = {"name": [], "nodes": [], "edges": [], "text": ""}
304
325
  for sub in getattr(subgraph, "to_pandas", lambda: subgraph)().itertuples(index=False):
305
- # Prepare the graph name
306
- print(f"Processing subgraph: {sub.name}")
307
- print('---')
308
- print(sub.nodes)
309
- print('---')
310
- print(sub.edges)
311
- print('---')
312
-
313
- # Prepare graph dataframes
314
- # Nodes
315
- coll_name = f"{cfg_db.milvus_db.database_name}_nodes"
316
- node_coll = Collection(name=coll_name)
317
- node_coll.load()
318
- graph_nodes = node_coll.query(
319
- expr=f'node_index IN [{",".join(f"{n}" for n in sub.nodes)}]',
320
- output_fields=['node_id', 'node_name', 'node_type', 'desc']
321
- )
322
- graph_nodes = df.DataFrame(graph_nodes)
323
- graph_nodes.drop(columns=['node_index'], inplace=True)
324
- if not color_df.empty:
325
- # Merge the color dataframe with the graph nodes
326
- graph_nodes = graph_nodes.merge(color_df, on="node_id", how="left")
327
- else:
328
- graph_nodes["color"] = 'black' # Default color
329
- graph_nodes['color'].fillna('black', inplace=True) # Fill NaN colors with black
330
- # Edges
331
- coll_name = f"{cfg_db.milvus_db.database_name}_edges"
332
- edge_coll = Collection(name=coll_name)
333
- edge_coll.load()
334
- graph_edges = edge_coll.query(
335
- expr=f'triplet_index IN [{",".join(f"{e}" for e in sub.edges)}]',
336
- output_fields=['head_id', 'tail_id', 'edge_type']
337
- )
338
- graph_edges = df.DataFrame(graph_edges)
339
- graph_edges.drop(columns=['triplet_index'], inplace=True)
340
- graph_edges['edge_type'] = graph_edges['edge_type'].str.split('|')
326
+ graph_nodes, graph_edges = self._process_subgraph_data(sub, cfg_db, color_df)
341
327
 
342
328
  # Prepare lists for visualization
343
329
  graph_dict["name"].append(sub.name)
344
- graph_dict["nodes"].append([(
345
- row.node_id,
346
- {'hover': "Node Name : " + row.node_name + "\n" +\
347
- "Node Type : " + row.node_type + "\n" +
348
- "Desc : " + row.desc,
349
- 'click': '$hover',
350
- 'color': row.color})
351
- for row in getattr(graph_nodes,
352
- "to_pandas",
353
- lambda: graph_nodes)().itertuples(index=False)])
354
- graph_dict["edges"].append([(
355
- row.head_id,
356
- row.tail_id,
357
- {'label': tuple(row.edge_type)})
358
- for row in getattr(graph_edges,
359
- "to_pandas",
360
- lambda: graph_edges)().itertuples(index=False)])
330
+ graph_dict["nodes"].append(
331
+ [
332
+ (
333
+ row.node_id,
334
+ {
335
+ "hover": "Node Name : "
336
+ + row.node_name
337
+ + "\n"
338
+ + "Node Type : "
339
+ + row.node_type
340
+ + "\n"
341
+ + "Desc : "
342
+ + row.desc,
343
+ "click": "$hover",
344
+ "color": row.color,
345
+ },
346
+ )
347
+ for row in getattr(
348
+ graph_nodes, "to_pandas", lambda graph_nodes=graph_nodes: graph_nodes
349
+ )().itertuples(index=False)
350
+ ]
351
+ )
352
+ graph_dict["edges"].append(
353
+ [
354
+ (row.head_id, row.tail_id, {"label": tuple(row.edge_type)})
355
+ for row in getattr(
356
+ graph_edges, "to_pandas", lambda graph_edges=graph_edges: graph_edges
357
+ )().itertuples(index=False)
358
+ ]
359
+ )
361
360
 
362
361
  # Prepare the textualized subgraph
363
362
  if sub.name == "Unified Subgraph":
364
- graph_nodes = graph_nodes[['node_id', 'desc']]
365
- graph_nodes.rename(columns={'desc': 'node_attr'}, inplace=True)
366
- graph_edges = graph_edges[['head_id', 'edge_type', 'tail_id']]
367
- graph_dict["text"] = (
368
- getattr(graph_nodes, "to_pandas", lambda: graph_nodes)().to_csv(index=False)
369
- + "\n"
370
- + getattr(graph_edges, "to_pandas", lambda: graph_edges)().to_csv(index=False)
371
- )
363
+ graph_nodes = graph_nodes[["node_id", "desc"]]
364
+ graph_nodes.rename(columns={"desc": "node_attr"}, inplace=True)
365
+ graph_edges = graph_edges[["head_id", "edge_type", "tail_id"]]
366
+ nodes_pandas = getattr(
367
+ graph_nodes, "to_pandas", lambda graph_nodes=graph_nodes: graph_nodes
368
+ )()
369
+ nodes_csv = nodes_pandas.to_csv(index=False)
370
+ edges_pandas = getattr(
371
+ graph_edges, "to_pandas", lambda graph_edges=graph_edges: graph_edges
372
+ )()
373
+ edges_csv = edges_pandas.to_csv(index=False)
374
+ graph_dict["text"] = nodes_csv + "\n" + edges_csv
372
375
 
373
376
  return graph_dict
374
377
 
375
- def normalize_vector(self,
376
- v : list) -> list:
378
def _process_subgraph_data(self, sub, cfg_db, color_df):
    """
    Fetch the node and edge tables of one extracted subgraph from Milvus.

    Args:
        sub: Subgraph record exposing ``name``, ``nodes`` and ``edges``. The
            entries of ``nodes`` / ``edges`` are interpolated into
            ``node_index IN [...]`` / ``triplet_index IN [...]`` filters.
        cfg_db: Database configuration; ``cfg_db.milvus_db.database_name`` is
            the prefix of the ``_nodes`` and ``_edges`` collection names.
        color_df: DataFrame with ``node_id`` and ``color`` columns. When it is
            empty every node is coloured ``"black"``.

    Returns:
        A ``(graph_nodes, graph_edges)`` tuple of DataFrames created through
        ``self.loader.df``.
    """
    # Trace the subgraph that is about to be resolved
    separator = "---"
    print(f"Processing subgraph: {sub.name}")
    print(separator)
    print(sub.nodes)
    print(separator)
    print(sub.edges)
    print(separator)

    frame_lib = self.loader.df
    db_name = cfg_db.milvus_db.database_name

    # ---- Nodes ----
    nodes_collection = Collection(name=f"{db_name}_nodes")
    nodes_collection.load()
    node_index_list = ",".join(f"{n}" for n in sub.nodes)
    node_records = nodes_collection.query(
        expr=f"node_index IN [{node_index_list}]",
        output_fields=["node_id", "node_name", "node_type", "desc"],
    )
    graph_nodes = frame_lib.DataFrame(node_records)
    # NOTE(review): node_index is not listed in output_fields; presumably the
    # query returns it anyway as the primary key - confirm against the schema.
    graph_nodes.drop(columns=["node_index"], inplace=True)
    if color_df.empty:
        graph_nodes["color"] = "black"
    else:
        # Left merge keeps nodes that have no colour entry
        graph_nodes = graph_nodes.merge(color_df, on="node_id", how="left")
    # Nodes left without a colour after the merge default to black
    graph_nodes["color"] = graph_nodes["color"].fillna("black")

    # ---- Edges ----
    edges_collection = Collection(name=f"{db_name}_edges")
    edges_collection.load()
    triplet_index_list = ",".join(f"{e}" for e in sub.edges)
    edge_records = edges_collection.query(
        expr=f"triplet_index IN [{triplet_index_list}]",
        output_fields=["head_id", "tail_id", "edge_type"],
    )
    graph_edges = frame_lib.DataFrame(edge_records)
    # NOTE(review): same assumption as node_index above - confirm.
    graph_edges.drop(columns=["triplet_index"], inplace=True)
    # edge_type is stored as a "|"-separated string; split it into its parts
    graph_edges["edge_type"] = graph_edges["edge_type"].str.split("|")

    return graph_nodes, graph_edges
416
+
417
+ def normalize_vector(self, v: list) -> list:
377
418
  """
378
- Normalize a vector using CuPy.
419
+ Normalize a vector using appropriate library (CuPy for GPU, NumPy for CPU).
379
420
 
380
421
  Args:
381
422
  v : Vector to normalize.
@@ -383,9 +424,13 @@ class MultimodalSubgraphExtractionTool(BaseTool):
383
424
  Returns:
384
425
  Normalized vector.
385
426
  """
386
- v = py.asarray(v)
387
- norm = py.linalg.norm(v)
388
- return (v / norm).tolist()
427
+ if self.loader.normalize_vectors:
428
+ # GPU mode: normalize the vector
429
+ v_array = self.loader.py.asarray(v)
430
+ norm = self.loader.py.linalg.norm(v_array)
431
+ return (v_array / norm).tolist()
432
+ # CPU mode: return as-is for COSINE similarity
433
+ return v
389
434
 
390
435
  def _run(
391
436
  self,
@@ -411,7 +456,8 @@ class MultimodalSubgraphExtractionTool(BaseTool):
411
456
  # Load hydra configuration
412
457
  with hydra.initialize(version_base=None, config_path="../configs"):
413
458
  cfg = hydra.compose(
414
- config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
459
+ config_name="config",
460
+ overrides=["tools/multimodal_subgraph_extraction=default"],
415
461
  )
416
462
  cfg_db = cfg.app.frontend
417
463
  cfg = cfg.tools.multimodal_subgraph_extraction
@@ -431,10 +477,9 @@ class MultimodalSubgraphExtractionTool(BaseTool):
431
477
  logger.log(logging.INFO, "_prepare_query_modalities")
432
478
  # start = datetime.datetime.now()
433
479
  query_df = self._prepare_query_modalities(
434
- {"text": prompt,
435
- "emb": [self.normalize_vector(
436
- state["embedding_model"].embed_query(prompt)
437
- )]
480
+ {
481
+ "text": prompt,
482
+ "emb": [self.normalize_vector(state["embedding_model"].embed_query(prompt))],
438
483
  },
439
484
  state,
440
485
  cfg_db,
@@ -446,10 +491,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
446
491
  # Perform subgraph extraction
447
492
  logger.log(logging.INFO, "_perform_subgraph_extraction")
448
493
  # start = datetime.datetime.now()
449
- subgraphs = self._perform_subgraph_extraction(state,
450
- cfg,
451
- cfg_db,
452
- query_df)
494
+ subgraphs = self._perform_subgraph_extraction(state, cfg, cfg_db, query_df)
453
495
  # end = datetime.datetime.now()
454
496
  # logger.log(logging.INFO, "_perform_subgraph_extraction time: %s seconds",
455
497
  # (end - start).total_seconds())
@@ -458,10 +500,7 @@ class MultimodalSubgraphExtractionTool(BaseTool):
458
500
  logger.log(logging.INFO, "_prepare_final_subgraph")
459
501
  logger.log(logging.INFO, "Subgraphs extracted: %s", len(subgraphs))
460
502
  # start = datetime.datetime.now()
461
- final_subgraph = self._prepare_final_subgraph(state,
462
- subgraphs,
463
- cfg,
464
- cfg_db)
503
+ final_subgraph = self._prepare_final_subgraph(state, subgraphs, cfg, cfg_db)
465
504
  # end = datetime.datetime.now()
466
505
  # logger.log(logging.INFO, "_prepare_final_subgraph time: %s seconds",
467
506
  # (end - start).total_seconds())
@@ -497,7 +536,8 @@ class MultimodalSubgraphExtractionTool(BaseTool):
497
536
 
498
537
  # Return the updated state of the tool
499
538
  return Command(
500
- update=dic_updated_state_for_model | {
539
+ update=dic_updated_state_for_model
540
+ | {
501
541
  # update the message history
502
542
  "messages": [
503
543
  ToolMessage(