tooluniverse 0.2.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tooluniverse might be problematic. Click here for more details.

Files changed (186)
  1. tooluniverse/__init__.py +340 -4
  2. tooluniverse/admetai_tool.py +84 -0
  3. tooluniverse/agentic_tool.py +563 -0
  4. tooluniverse/alphafold_tool.py +96 -0
  5. tooluniverse/base_tool.py +129 -6
  6. tooluniverse/boltz_tool.py +207 -0
  7. tooluniverse/chem_tool.py +192 -0
  8. tooluniverse/compose_scripts/__init__.py +1 -0
  9. tooluniverse/compose_scripts/biomarker_discovery.py +293 -0
  10. tooluniverse/compose_scripts/comprehensive_drug_discovery.py +186 -0
  11. tooluniverse/compose_scripts/drug_safety_analyzer.py +89 -0
  12. tooluniverse/compose_scripts/literature_tool.py +34 -0
  13. tooluniverse/compose_scripts/output_summarizer.py +279 -0
  14. tooluniverse/compose_scripts/tool_description_optimizer.py +681 -0
  15. tooluniverse/compose_scripts/tool_discover.py +705 -0
  16. tooluniverse/compose_scripts/tool_graph_composer.py +448 -0
  17. tooluniverse/compose_tool.py +371 -0
  18. tooluniverse/ctg_tool.py +1002 -0
  19. tooluniverse/custom_tool.py +81 -0
  20. tooluniverse/dailymed_tool.py +108 -0
  21. tooluniverse/data/admetai_tools.json +155 -0
  22. tooluniverse/data/adverse_event_tools.json +108 -0
  23. tooluniverse/data/agentic_tools.json +1156 -0
  24. tooluniverse/data/alphafold_tools.json +87 -0
  25. tooluniverse/data/boltz_tools.json +9 -0
  26. tooluniverse/data/chembl_tools.json +16 -0
  27. tooluniverse/data/clinicaltrials_gov_tools.json +326 -0
  28. tooluniverse/data/compose_tools.json +202 -0
  29. tooluniverse/data/dailymed_tools.json +70 -0
  30. tooluniverse/data/dataset_tools.json +646 -0
  31. tooluniverse/data/disease_target_score_tools.json +712 -0
  32. tooluniverse/data/efo_tools.json +17 -0
  33. tooluniverse/data/embedding_tools.json +319 -0
  34. tooluniverse/data/enrichr_tools.json +31 -0
  35. tooluniverse/data/europe_pmc_tools.json +22 -0
  36. tooluniverse/data/expert_feedback_tools.json +10 -0
  37. tooluniverse/data/fda_drug_adverse_event_tools.json +491 -0
  38. tooluniverse/data/fda_drug_labeling_tools.json +1 -1
  39. tooluniverse/data/fda_drugs_with_brand_generic_names_for_tool.py +76929 -148860
  40. tooluniverse/data/finder_tools.json +209 -0
  41. tooluniverse/data/gene_ontology_tools.json +113 -0
  42. tooluniverse/data/gwas_tools.json +1082 -0
  43. tooluniverse/data/hpa_tools.json +333 -0
  44. tooluniverse/data/humanbase_tools.json +47 -0
  45. tooluniverse/data/idmap_tools.json +74 -0
  46. tooluniverse/data/mcp_client_tools_example.json +113 -0
  47. tooluniverse/data/mcpautoloadertool_defaults.json +28 -0
  48. tooluniverse/data/medlineplus_tools.json +141 -0
  49. tooluniverse/data/monarch_tools.json +1 -1
  50. tooluniverse/data/openalex_tools.json +36 -0
  51. tooluniverse/data/opentarget_tools.json +1 -1
  52. tooluniverse/data/output_summarization_tools.json +101 -0
  53. tooluniverse/data/packages/bioinformatics_core_tools.json +1756 -0
  54. tooluniverse/data/packages/categorized_tools.txt +206 -0
  55. tooluniverse/data/packages/cheminformatics_tools.json +347 -0
  56. tooluniverse/data/packages/earth_sciences_tools.json +74 -0
  57. tooluniverse/data/packages/genomics_tools.json +776 -0
  58. tooluniverse/data/packages/image_processing_tools.json +38 -0
  59. tooluniverse/data/packages/machine_learning_tools.json +789 -0
  60. tooluniverse/data/packages/neuroscience_tools.json +62 -0
  61. tooluniverse/data/packages/original_tools.txt +0 -0
  62. tooluniverse/data/packages/physics_astronomy_tools.json +62 -0
  63. tooluniverse/data/packages/scientific_computing_tools.json +560 -0
  64. tooluniverse/data/packages/single_cell_tools.json +453 -0
  65. tooluniverse/data/packages/structural_biology_tools.json +396 -0
  66. tooluniverse/data/packages/visualization_tools.json +399 -0
  67. tooluniverse/data/pubchem_tools.json +215 -0
  68. tooluniverse/data/pubtator_tools.json +68 -0
  69. tooluniverse/data/rcsb_pdb_tools.json +1332 -0
  70. tooluniverse/data/reactome_tools.json +19 -0
  71. tooluniverse/data/semantic_scholar_tools.json +26 -0
  72. tooluniverse/data/special_tools.json +2 -25
  73. tooluniverse/data/tool_composition_tools.json +88 -0
  74. tooluniverse/data/toolfinderkeyword_defaults.json +34 -0
  75. tooluniverse/data/txagent_client_tools.json +9 -0
  76. tooluniverse/data/uniprot_tools.json +211 -0
  77. tooluniverse/data/url_fetch_tools.json +94 -0
  78. tooluniverse/data/uspto_downloader_tools.json +9 -0
  79. tooluniverse/data/uspto_tools.json +811 -0
  80. tooluniverse/data/xml_tools.json +3275 -0
  81. tooluniverse/dataset_tool.py +296 -0
  82. tooluniverse/default_config.py +165 -0
  83. tooluniverse/efo_tool.py +42 -0
  84. tooluniverse/embedding_database.py +630 -0
  85. tooluniverse/embedding_sync.py +396 -0
  86. tooluniverse/enrichr_tool.py +266 -0
  87. tooluniverse/europe_pmc_tool.py +52 -0
  88. tooluniverse/execute_function.py +1775 -95
  89. tooluniverse/extended_hooks.py +444 -0
  90. tooluniverse/gene_ontology_tool.py +194 -0
  91. tooluniverse/graphql_tool.py +158 -36
  92. tooluniverse/gwas_tool.py +358 -0
  93. tooluniverse/hpa_tool.py +1645 -0
  94. tooluniverse/humanbase_tool.py +389 -0
  95. tooluniverse/logging_config.py +254 -0
  96. tooluniverse/mcp_client_tool.py +764 -0
  97. tooluniverse/mcp_integration.py +413 -0
  98. tooluniverse/mcp_tool_registry.py +925 -0
  99. tooluniverse/medlineplus_tool.py +337 -0
  100. tooluniverse/openalex_tool.py +228 -0
  101. tooluniverse/openfda_adv_tool.py +283 -0
  102. tooluniverse/openfda_tool.py +393 -160
  103. tooluniverse/output_hook.py +1122 -0
  104. tooluniverse/package_tool.py +195 -0
  105. tooluniverse/pubchem_tool.py +158 -0
  106. tooluniverse/pubtator_tool.py +168 -0
  107. tooluniverse/rcsb_pdb_tool.py +38 -0
  108. tooluniverse/reactome_tool.py +108 -0
  109. tooluniverse/remote/boltz/boltz_mcp_server.py +50 -0
  110. tooluniverse/remote/depmap_24q2/depmap_24q2_mcp_tool.py +442 -0
  111. tooluniverse/remote/expert_feedback/human_expert_mcp_tools.py +2013 -0
  112. tooluniverse/remote/expert_feedback/simple_test.py +23 -0
  113. tooluniverse/remote/expert_feedback/start_web_interface.py +188 -0
  114. tooluniverse/remote/expert_feedback/web_only_interface.py +0 -0
  115. tooluniverse/remote/immune_compass/compass_tool.py +327 -0
  116. tooluniverse/remote/pinnacle/pinnacle_tool.py +328 -0
  117. tooluniverse/remote/transcriptformer/transcriptformer_tool.py +586 -0
  118. tooluniverse/remote/uspto_downloader/uspto_downloader_mcp_server.py +61 -0
  119. tooluniverse/remote/uspto_downloader/uspto_downloader_tool.py +120 -0
  120. tooluniverse/remote_tool.py +99 -0
  121. tooluniverse/restful_tool.py +53 -30
  122. tooluniverse/scripts/generate_tool_graph.py +408 -0
  123. tooluniverse/scripts/visualize_tool_graph.py +829 -0
  124. tooluniverse/semantic_scholar_tool.py +62 -0
  125. tooluniverse/smcp.py +2452 -0
  126. tooluniverse/smcp_server.py +975 -0
  127. tooluniverse/test/mcp_server_test.py +0 -0
  128. tooluniverse/test/test_admetai_tool.py +370 -0
  129. tooluniverse/test/test_agentic_tool.py +129 -0
  130. tooluniverse/test/test_alphafold_tool.py +71 -0
  131. tooluniverse/test/test_chem_tool.py +37 -0
  132. tooluniverse/test/test_compose_lieraturereview.py +63 -0
  133. tooluniverse/test/test_compose_tool.py +448 -0
  134. tooluniverse/test/test_dailymed.py +69 -0
  135. tooluniverse/test/test_dataset_tool.py +200 -0
  136. tooluniverse/test/test_disease_target_score.py +56 -0
  137. tooluniverse/test/test_drugbank_filter_examples.py +179 -0
  138. tooluniverse/test/test_efo.py +31 -0
  139. tooluniverse/test/test_enrichr_tool.py +21 -0
  140. tooluniverse/test/test_europe_pmc_tool.py +20 -0
  141. tooluniverse/test/test_fda_adv.py +95 -0
  142. tooluniverse/test/test_fda_drug_labeling.py +91 -0
  143. tooluniverse/test/test_gene_ontology_tools.py +66 -0
  144. tooluniverse/test/test_gwas_tool.py +139 -0
  145. tooluniverse/test/test_hpa.py +625 -0
  146. tooluniverse/test/test_humanbase_tool.py +20 -0
  147. tooluniverse/test/test_idmap_tools.py +61 -0
  148. tooluniverse/test/test_mcp_server.py +211 -0
  149. tooluniverse/test/test_mcp_tool.py +247 -0
  150. tooluniverse/test/test_medlineplus.py +220 -0
  151. tooluniverse/test/test_openalex_tool.py +32 -0
  152. tooluniverse/test/test_opentargets.py +28 -0
  153. tooluniverse/test/test_pubchem_tool.py +116 -0
  154. tooluniverse/test/test_pubtator_tool.py +37 -0
  155. tooluniverse/test/test_rcsb_pdb_tool.py +86 -0
  156. tooluniverse/test/test_reactome.py +54 -0
  157. tooluniverse/test/test_semantic_scholar_tool.py +24 -0
  158. tooluniverse/test/test_software_tools.py +147 -0
  159. tooluniverse/test/test_tool_description_optimizer.py +49 -0
  160. tooluniverse/test/test_tool_finder.py +26 -0
  161. tooluniverse/test/test_tool_finder_llm.py +252 -0
  162. tooluniverse/test/test_tools_find.py +195 -0
  163. tooluniverse/test/test_uniprot_tools.py +74 -0
  164. tooluniverse/test/test_uspto_tool.py +72 -0
  165. tooluniverse/test/test_xml_tool.py +113 -0
  166. tooluniverse/tool_finder_embedding.py +267 -0
  167. tooluniverse/tool_finder_keyword.py +693 -0
  168. tooluniverse/tool_finder_llm.py +699 -0
  169. tooluniverse/tool_graph_web_ui.py +955 -0
  170. tooluniverse/tool_registry.py +416 -0
  171. tooluniverse/uniprot_tool.py +155 -0
  172. tooluniverse/url_tool.py +253 -0
  173. tooluniverse/uspto_tool.py +240 -0
  174. tooluniverse/utils.py +369 -41
  175. tooluniverse/xml_tool.py +369 -0
  176. tooluniverse-1.0.1.dist-info/METADATA +387 -0
  177. tooluniverse-1.0.1.dist-info/RECORD +182 -0
  178. tooluniverse-1.0.1.dist-info/entry_points.txt +9 -0
  179. tooluniverse/generate_mcp_tools.py +0 -113
  180. tooluniverse/mcp_server.py +0 -3340
  181. tooluniverse-0.2.0.dist-info/METADATA +0 -139
  182. tooluniverse-0.2.0.dist-info/RECORD +0 -21
  183. tooluniverse-0.2.0.dist-info/entry_points.txt +0 -4
  184. {tooluniverse-0.2.0.dist-info → tooluniverse-1.0.1.dist-info}/WHEEL +0 -0
  185. {tooluniverse-0.2.0.dist-info → tooluniverse-1.0.1.dist-info}/licenses/LICENSE +0 -0
  186. {tooluniverse-0.2.0.dist-info → tooluniverse-1.0.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,74 @@
1
+ from tooluniverse import ToolUniverse
2
+ from typing import Any, Dict, List
3
+
4
tooluni = ToolUniverse()
tooluni.load_tools()

TEST_ACC = "P05067"  # A4_HUMAN

# Every UniProt tool exercised here takes the same single "accession"
# argument, so the query list is derived from the tool names rather than
# written out entry by entry. Order matches the run order.
_UNIPROT_TOOL_NAMES = (
    "UniProt_get_entry_by_accession",
    "UniProt_get_function_by_accession",
    "UniProt_get_recommended_name_by_accession",
    "UniProt_get_alternative_names_by_accession",
    "UniProt_get_organism_by_accession",
    "UniProt_get_subcellular_location_by_accession",
    "UniProt_get_disease_variants_by_accession",
    "UniProt_get_ptm_processing_by_accession",
    "UniProt_get_sequence_by_accession",
    "UniProt_get_isoform_ids_by_accession",
)

test_queries: List[Dict[str, Any]] = [
    {"name": tool_name, "arguments": {"accession": TEST_ACC}}
    for tool_name in _UNIPROT_TOOL_NAMES
]
39
+
40
+
41
def format_value(value, max_items=5, max_length=200, max_dict_length=500):
    """
    Render a tool result as a short, human-readable summary.

    Args:
        value: The value to describe.
        max_items: For lists, how many leading items to show.
        max_length: For strings, and for each shown list item, how many
            characters to keep before truncating with "...".
        max_dict_length: For dicts, how many characters of ``str(value)`` to
            keep before truncating with "...". Defaults to 500, the limit
            previously hard-coded for dicts.

    Returns:
        str: One of "Dict (N chars): ...", "Empty list",
        "List with N items: ...", "String (N chars): ..." or
        "Type: <type>, Value: <value>".
    """

    def _truncate(text, limit):
        # Keep at most `limit` characters, marking any cut with "...".
        return text[:limit] + ("..." if len(text) > limit else "")

    if isinstance(value, dict):
        dict_str = str(value)
        return f"Dict ({len(dict_str)} chars): {_truncate(dict_str, max_dict_length)}"

    if isinstance(value, list):
        if not value:
            return "Empty list"
        items_str = "\n - ".join(
            _truncate(str(item), max_length) for item in value[:max_items]
        )
        remaining = len(value) - max_items
        more = f"\n ... and {remaining} more items" if remaining > 0 else ""
        return f"List with {len(value)} items:\n - {items_str}{more}"

    if isinstance(value, str):
        return f"String ({len(value)} chars): {_truncate(value, max_length)}"

    return f"Type: {type(value)}, Value: {value}"
64
+
65
+
66
# Run only when executed as a script: this file lives under tooluniverse/test/
# and is named test_*.py, so test collectors import it. Without the guard,
# every UniProt call would fire at import time.
if __name__ == "__main__":
    for idx, q in enumerate(test_queries, 1):
        print(f"\n{'='*80}\n[{idx}] {q['name']}({q['arguments']['accession']})")
        try:
            res = tooluni.run(q)
        except Exception as exc:  # top-level smoke-test boundary: report and keep going
            print(f"ERROR: {type(exc).__name__}: {exc}")
            print()
            continue

        # Tools report failures as a dict carrying an "error" key.
        if isinstance(res, dict) and "error" in res:
            print(f"ERROR: {res['error']}")
        else:
            print(format_value(res))
        print()
@@ -0,0 +1,72 @@
1
+ from tooluniverse import ToolUniverse
2
+
3
+ tooluni = ToolUniverse()
4
+ tooluni.load_tools()
5
+
6
+ # All test cases compiled into the specified list format
7
+ test_queries = [
8
+ # Test Case: get_patent_overview_by_text_query example
9
+ {
10
+ "name": "get_patent_overview_by_text_query",
11
+ "arguments": {
12
+ "query": "iron oxide",
13
+ "exact_match": True,
14
+ "sort": "filingDate desc",
15
+ "limit": 5,
16
+ "rangeFilters": "filingDate 2021-01-01:2024-02-01",
17
+ },
18
+ },
19
+ # Test Case: get_patent_overview_by_text_query example
20
+ {
21
+ "name": "get_patent_overview_by_text_query",
22
+ "arguments": {
23
+ "query": "machine learning",
24
+ "exact_match": False,
25
+ "sort": "filingDate desc",
26
+ "limit": 1,
27
+ "offset": 53,
28
+ "rangeFilters": "filingDate 2021-01-01:2024-02-01",
29
+ },
30
+ },
31
+ # Test Case: get_patent_application_metadata
32
+ {
33
+ "name": "get_patent_application_metadata",
34
+ "arguments": {"applicationNumberText": "19053071"},
35
+ },
36
+ # Test Case: get_patent_term_adjustment_data
37
+ {
38
+ "name": "get_patent_term_adjustment_data",
39
+ "arguments": {"applicationNumberText": "16232347"},
40
+ },
41
+ # Test Case: get_patent_term_adjustment_data
42
+ {
43
+ "name": "get_patent_term_adjustment_data",
44
+ "arguments": {"applicationNumberText": "17783167"},
45
+ },
46
+ # Test Case: get_patent_continuity_data
47
+ {
48
+ "name": "get_patent_continuity_data",
49
+ "arguments": {"applicationNumberText": "19053071"},
50
+ },
51
+ # Test Case: get_patient_foreign_priority_data example
52
+ {
53
+ "name": "get_patent_foreign_priority_data",
54
+ "arguments": {"applicationNumberText": "19053071"},
55
+ },
56
+ # Test Case: get_associated_documents_metadata
57
+ {
58
+ "name": "get_associated_documents_metadata",
59
+ "arguments": {"applicationNumberText": "16232347"},
60
+ },
61
+ ]
62
+
63
+ test_queries = test_queries # Repeat the test cases three times for thorough testing
64
+
65
+ for idx, query in enumerate(test_queries):
66
+ print(
67
+ f"\n[{idx+1}] Running tool: {query['name']} with arguments: {query['arguments']}"
68
+ )
69
+ result = tooluni.run(query)
70
+ print("✅ Success.")
71
+ result_str = str(result)
72
+ print(f"📊 Result: {result_str}")
@@ -0,0 +1,113 @@
1
+ from tooluniverse import ToolUniverse
2
+ import json
3
+
4
+ # Step 1: Initialize tool universe
5
+ tooluni = ToolUniverse()
6
+ tooluni.load_tools()
7
+
8
+ # Test queries for XML tools using MedlinePlus health topics data
9
+ test_queries = [
10
+ {
11
+ "name": "mesh_get_subjects_by_pharmacological_action",
12
+ "arguments": {"query": "calcium", "limit": 10},
13
+ },
14
+ {
15
+ "name": "mesh_get_subjects_by_subject_scope_or_definition",
16
+ "arguments": {"query": "glycan", "limit": 2},
17
+ },
18
+ {
19
+ "name": "mesh_get_subjects_by_subject_name",
20
+ "arguments": {
21
+ "query": "antibody",
22
+ "limit": 10,
23
+ },
24
+ },
25
+ {
26
+ "name": "mesh_get_subjects_by_subject_id",
27
+ "arguments": {
28
+ "query": "D007306",
29
+ "limit": 5,
30
+ },
31
+ },
32
+ {
33
+ "name": "drugbank_get_drug_basic_info_by_drug_name_or_drugbank_id",
34
+ "arguments": {"query": "lovastatin", "limit": 2},
35
+ },
36
+ {
37
+ "name": "drugbank_get_indications_by_drug_name_or_drugbank_id",
38
+ "arguments": {"query": "DB00945", "limit": 5},
39
+ },
40
+ {
41
+ "name": "drugbank_get_drug_name_and_description_by_indication",
42
+ "arguments": {"query": "hypertension", "limit": 1},
43
+ },
44
+ {
45
+ "name": "drugbank_get_pharmacology_by_drug_name_or_drugbank_id",
46
+ "arguments": {"query": "lovastatin", "limit": 1},
47
+ },
48
+ {
49
+ "name": "drugbank_get_pharmacology_by_drug_name_or_drugbank_id",
50
+ "arguments": {"query": "simvastatin", "limit": 1},
51
+ },
52
+ {
53
+ "name": "drugbank_get_drug_name_description_pharmacology_by_mechanism_of_action",
54
+ "arguments": {"query": "receptor antagonist", "limit": 1},
55
+ },
56
+ {
57
+ "name": "drugbank_get_drug_interactions_by_drug_name_or_drugbank_id",
58
+ "arguments": {"query": "carbidopa", "limit": 1},
59
+ },
60
+ {
61
+ "name": "drugbank_get_targets_by_drug_name_or_drugbank_id",
62
+ "arguments": {"query": "aspirin", "limit": 1},
63
+ },
64
+ {
65
+ "name": "drugbank_get_drug_name_and_description_by_target_name",
66
+ "arguments": {"query": "dopamine receptor", "limit": 1},
67
+ },
68
+ {
69
+ "name": "drugbank_get_drug_products_by_name_or_drugbank_id",
70
+ "arguments": {"query": "ibuprofen", "limit": 1},
71
+ },
72
+ {
73
+ "name": "drugbank_get_safety_by_drug_name_or_drugbank_id",
74
+ "arguments": {"query": "lovastatin", "limit": 2},
75
+ },
76
+ {
77
+ "name": "drugbank_get_drug_chemistry_by_drug_name_or_drugbank_id",
78
+ "arguments": {"query": "caffeine", "limit": 1},
79
+ },
80
+ {
81
+ "name": "drugbank_get_drug_references_by_drug_name_or_drugbank_id",
82
+ "arguments": {"query": "aspirin", "limit": 1},
83
+ },
84
+ {
85
+ "name": "drugbank_get_drug_pathways_and_reactions_by_drug_name_or_drugbank_id",
86
+ "arguments": {"query": "glucose", "limit": 1},
87
+ },
88
+ {
89
+ "name": "drugbank_get_drug_name_and_description_by_pathway_name",
90
+ "arguments": {"query": "glycolysis", "limit": 1},
91
+ },
92
+ {
93
+ "name": "drugbank_filter_drugs_by_name",
94
+ "arguments": {
95
+ "condition": "ends_with",
96
+ "value": "cillin", # Example: find drugs whose names end with 'cillin', pencillin antibiotics
97
+ "limit": 1,
98
+ },
99
+ },
100
+ ]
101
+
102
+ test_queries = test_queries
103
+
104
+ # Run all test queries
105
+ for idx, query in enumerate(test_queries):
106
+ print(f"\n[{idx+1}] Running tool: {query['name']}")
107
+ print(f"Arguments: {query['arguments']}")
108
+ print("-" * 60)
109
+
110
+ # try:
111
+ result = tooluni.run(query)
112
+ print("✅ Success!")
113
+ print(json.dumps(result, indent=2, ensure_ascii=False))
@@ -0,0 +1,267 @@
1
+ from sentence_transformers import SentenceTransformer
2
+ import torch
3
+ import json
4
+ import gc
5
+ from .utils import get_md5
6
+ from .base_tool import BaseTool
7
+ from .tool_registry import register_tool
8
+
9
+
10
@register_tool("ToolFinderEmbedding")
class ToolFinderEmbedding(BaseTool):
    """
    Tool finder that ranks available tools by embedding similarity to a query.

    Tool descriptions are encoded once with a sentence-transformers model and
    cached on disk (keyed by an MD5 of the tool prompts). Each query is
    encoded with the same model, and the highest-scoring tools are returned.

    Attributes:
        toolfinder_model (str): Name/path of the sentence-transformers model,
            read from ``tool_config["configs"]["tool_finder_model"]``.
        rag_model (SentenceTransformer): The loaded embedding model.
        tool_desc_embedding: Tool-description embeddings, one entry per
            element of ``tool_name``. This is whatever ``rag_model.encode``
            returned, or the object previously cached with ``torch.save``.
        tool_name (list): Tool names, aligned index-for-index with
            ``tool_desc_embedding``.
        tool_embedding_path (str): File (relative to the current working
            directory) the embeddings are cached in.
        exclude_tools (list): Tool names that are never returned.
        tooluniverse: The ToolUniverse instance whose tools are searched.
    """

    def __init__(self, tool_config, tooluniverse):
        """
        Load the embedding model and the tool-description embeddings.

        Args:
            tool_config (dict): Tool configuration.
                ``configs.tool_finder_model`` names the sentence-transformers
                model. ``exclude_tools`` (top level, else under ``configs``)
                lists tool names to ignore.
            tooluniverse: ToolUniverse instance providing the tools to index.
        """
        super().__init__(tool_config)
        self.rag_model = None
        self.tool_desc_embedding = None
        self.tool_name = None
        self.tool_embedding_path = None
        # Treat "configs" as optional everywhere (previously one lookup
        # indexed it directly and raised KeyError when it was absent).
        configs = tool_config.get("configs") or {}
        toolfinder_model = configs.get("tool_finder_model")
        self.toolfinder_model = toolfinder_model
        # A top-level "exclude_tools" wins over configs["exclude_tools"];
        # otherwise fall back to the default list below.
        self.exclude_tools = tool_config.get(
            "exclude_tools",
            configs.get(
                "exclude_tools", ["Tool_RAG", "Tool_Finder", "Finish", "CallAgent"]
            ),
        )
        self.load_rag_model()
        print(
            f"Using toolfinder model: {toolfinder_model}, GPU is required for this model for fast speed..."
        )
        self.load_tool_desc_embedding(tooluniverse, exclude_names=self.exclude_tools)

    def load_rag_model(self):
        """
        Load the sentence-transformers model named by ``self.toolfinder_model``.

        The maximum sequence length is set to 4096 tokens and the tokenizer
        pads on the right.
        """
        self.rag_model = SentenceTransformer(self.toolfinder_model)
        self.rag_model.max_seq_length = 4096
        self.rag_model.tokenizer.padding_side = "right"

    def load_tool_desc_embedding(
        self,
        tooluniverse,
        include_names=None,
        exclude_names=None,
        include_categories=None,
        exclude_categories=None,
    ):
        """
        Load cached tool-description embeddings, or compute and cache them.

        The cache file name combines the model name with an MD5 of the
        serialized tool prompts. Any change to the selected tools therefore
        produces a new cache file. After fresh encoding, GPU and CPU memory
        are released.

        Args:
            tooluniverse: ToolUniverse instance containing all available tools.
            include_names (list, optional): Specific tool names to include.
            exclude_names (list, optional): Tool names to exclude.
            include_categories (list, optional): Tool categories to include.
            exclude_categories (list, optional): Tool categories to exclude.
        """
        import pickle

        self.tooluniverse = tooluniverse
        print("Loading tool descriptions and embeddings...")
        self.tool_name, _ = tooluniverse.refresh_tool_name_desc(
            enable_full_desc=True,
            include_names=include_names,
            exclude_names=exclude_names,
            include_categories=include_categories,
            exclude_categories=exclude_categories,
        )

        # rag_infer maps embedding row i back to self.tool_name[i], so the
        # tools must be encoded in exactly tool_name order. The first
        # occurrence wins if all_tools repeats a name. Names with no matching
        # tool are dropped so the two lists stay aligned.
        tools_by_name = {}
        for tool in tooluniverse.all_tools:
            tools_by_name.setdefault(tool["name"], tool)
        self.tool_name = [name for name in self.tool_name if name in tools_by_name]
        filtered_tools = [tools_by_name[name] for name in self.tool_name]

        all_tools_str = [
            json.dumps(each)
            for each in tooluniverse.prepare_tool_prompts(filtered_tools)
        ]
        md5_value = get_md5(str(all_tools_str))
        print("get the md value of tools:", md5_value)
        # File name format is unchanged so existing caches stay valid.
        self.tool_embedding_path = (
            self.toolfinder_model.split("/")[-1] + "tool_embedding_" + md5_value + ".pt"
        )
        try:
            # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
            # arbitrary objects. The file is read from the current working
            # directory, so only run where that directory is trusted.
            self.tool_desc_embedding = torch.load(
                self.tool_embedding_path, weights_only=False
            )
            # Explicit check (not `assert`) so a stale cache is still rejected
            # under `python -O`.
            if len(self.tool_desc_embedding) != len(self.tool_name):
                raise ValueError(
                    "The number of tools in the tool_name list is not equal to the number of tool_desc_embedding."
                )
            print("\033[92mSuccessfully loaded cached embeddings.\033[0m")
        except (
            RuntimeError,
            ValueError,
            TypeError,
            AssertionError,
            OSError,
            EOFError,
            pickle.UnpicklingError,
        ):
            # Missing, truncated, corrupt or mismatched cache: re-encode.
            self.tool_desc_embedding = None
            print("\033[92mInferring the tool_desc_embedding.\033[0m")

            self.tool_desc_embedding = self.rag_model.encode(
                all_tools_str, prompt="", normalize_embeddings=True
            )

            # Caching is best effort: the in-memory embeddings remain usable
            # even if the working directory is not writable.
            try:
                torch.save(self.tool_desc_embedding, self.tool_embedding_path)
            except (OSError, RuntimeError) as err:
                print(
                    f"Warning: could not cache tool embeddings to {self.tool_embedding_path}: {err}"
                )
            else:
                print(
                    "\033[92mFinished inferring and saving the tool_desc_embedding.\033[0m"
                )

            del all_tools_str

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            gc.collect()

            print(
                "\033[92mMemory cleanup completed. Embeddings are ready for use.\033[0m"
            )

    def rag_infer(self, query, top_k=5):
        """
        Return the names of the tools most similar to ``query``.

        Args:
            query (str): User query or description of the desired functionality.
            top_k (int, optional): Number of tools to return. Clamped to
                ``[0, len(self.tool_name)]``. Defaults to 5.

        Returns:
            list: Up to ``top_k`` tool names, most relevant first.

        Raises:
            RuntimeError: If the tool-description embeddings are not loaded.
                Previously this called ``exit()`` and terminated the whole
                interpreter.
        """
        # Check before encoding so no work is wasted on an unusable index.
        if self.tool_desc_embedding is None:
            raise RuntimeError(
                "No tool_desc_embedding; call load_tool_desc_embedding() first."
            )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        query_embeddings = self.rag_model.encode(
            [query], prompt="", normalize_embeddings=True
        )
        scores = self.rag_model.similarity(query_embeddings, self.tool_desc_embedding)
        # torch.topk rejects k outside [0, number of scores].
        top_k = max(0, min(top_k, len(self.tool_name)))
        # scores has one row (one query); take that row's indices.
        top_k_indices = torch.topk(scores, top_k).indices.tolist()[0]
        return [self.tool_name[i] for i in top_k_indices]

    def find_tools(
        self,
        message=None,
        picked_tool_names=None,
        rag_num=5,
        return_call_result=False,
        categories=None,
    ):
        """
        Find relevant tools for a message, or process pre-selected tool names.

        When ``picked_tool_names`` is None, ``rag_infer`` retrieves
        ``int(rag_num * 1.5)`` candidates. The extra candidates leave room for
        excluded tools. Names in ``self.exclude_tools`` are then dropped and at
        most ``rag_num`` names are kept.

        Args:
            message (str, optional): Query to find tools for. Required if
                picked_tool_names is None.
            picked_tool_names (list, optional): Pre-selected tool names to process.
            rag_num (int, optional): Maximum number of tools to return. Defaults to 5.
            return_call_result (bool, optional): If True, also return the tool
                names. Defaults to False.
            categories (list, optional): Accepted for interface compatibility.
                Not used by embedding-based search.

        Returns:
            The result of ``tooluniverse.prepare_tool_prompts`` for the picked
            tools. If return_call_result is True, the tuple
            ``(tool_prompts, tool_names)``.

        Raises:
            AssertionError: If both message and picked_tool_names are None.
        """
        extra_factor = 1.5  # over-fetch so exclusions don't shrink the result
        if picked_tool_names is None:
            if message is None:
                # Raised explicitly (not via `assert`) so the check survives
                # `python -O`. AssertionError is kept as the documented contract.
                raise AssertionError(
                    "Either 'message' or 'picked_tool_names' must be provided."
                )
            picked_tool_names = self.rag_infer(
                message, top_k=int(rag_num * extra_factor)
            )

        picked_tool_names = [
            name for name in picked_tool_names if name not in self.exclude_tools
        ][:rag_num]

        picked_tools = self.tooluniverse.get_tool_by_name(picked_tool_names)
        picked_tools_prompt = self.tooluniverse.prepare_tool_prompts(picked_tools)
        if return_call_result:
            return picked_tools_prompt, picked_tool_names
        return picked_tools_prompt

    def run(self, arguments):
        """
        Run the tool finder through the standard tool interface.

        Args:
            arguments (dict): May contain:
                - description (str): Query message (maps to ``message``).
                - limit (int): Number of tools to return (maps to ``rag_num``).
                  Defaults to 5.
                - picked_tool_names (list): Pre-selected tool names to process.
                - return_call_result (bool): Whether to also return the tool
                  names. Defaults to False.
                - categories (list): Passed through to ``find_tools``.

        Returns:
            Whatever ``find_tools`` returns for these arguments.
        """
        import copy

        # Work on a copy so the caller's dict is never modified downstream.
        arguments = copy.deepcopy(arguments)

        message = arguments.get("description", None)
        rag_num = arguments.get("limit", 5)
        picked_tool_names = arguments.get("picked_tool_names", None)
        return_call_result = arguments.get("return_call_result", False)
        categories = arguments.get("categories", None)

        return self.find_tools(
            message=message,
            picked_tool_names=picked_tool_names,
            rag_num=rag_num,
            return_call_result=return_call_result,
            categories=categories,
        )