tooluniverse 0.2.0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tooluniverse has been flagged as potentially problematic; see the package registry's advisory page for details.
- tooluniverse/__init__.py +340 -4
- tooluniverse/admetai_tool.py +84 -0
- tooluniverse/agentic_tool.py +563 -0
- tooluniverse/alphafold_tool.py +96 -0
- tooluniverse/base_tool.py +129 -6
- tooluniverse/boltz_tool.py +207 -0
- tooluniverse/chem_tool.py +192 -0
- tooluniverse/compose_scripts/__init__.py +1 -0
- tooluniverse/compose_scripts/biomarker_discovery.py +293 -0
- tooluniverse/compose_scripts/comprehensive_drug_discovery.py +186 -0
- tooluniverse/compose_scripts/drug_safety_analyzer.py +89 -0
- tooluniverse/compose_scripts/literature_tool.py +34 -0
- tooluniverse/compose_scripts/output_summarizer.py +279 -0
- tooluniverse/compose_scripts/tool_description_optimizer.py +681 -0
- tooluniverse/compose_scripts/tool_discover.py +705 -0
- tooluniverse/compose_scripts/tool_graph_composer.py +448 -0
- tooluniverse/compose_tool.py +371 -0
- tooluniverse/ctg_tool.py +1002 -0
- tooluniverse/custom_tool.py +81 -0
- tooluniverse/dailymed_tool.py +108 -0
- tooluniverse/data/admetai_tools.json +155 -0
- tooluniverse/data/adverse_event_tools.json +108 -0
- tooluniverse/data/agentic_tools.json +1156 -0
- tooluniverse/data/alphafold_tools.json +87 -0
- tooluniverse/data/boltz_tools.json +9 -0
- tooluniverse/data/chembl_tools.json +16 -0
- tooluniverse/data/clinicaltrials_gov_tools.json +326 -0
- tooluniverse/data/compose_tools.json +202 -0
- tooluniverse/data/dailymed_tools.json +70 -0
- tooluniverse/data/dataset_tools.json +646 -0
- tooluniverse/data/disease_target_score_tools.json +712 -0
- tooluniverse/data/efo_tools.json +17 -0
- tooluniverse/data/embedding_tools.json +319 -0
- tooluniverse/data/enrichr_tools.json +31 -0
- tooluniverse/data/europe_pmc_tools.json +22 -0
- tooluniverse/data/expert_feedback_tools.json +10 -0
- tooluniverse/data/fda_drug_adverse_event_tools.json +491 -0
- tooluniverse/data/fda_drug_labeling_tools.json +1 -1
- tooluniverse/data/fda_drugs_with_brand_generic_names_for_tool.py +76929 -148860
- tooluniverse/data/finder_tools.json +209 -0
- tooluniverse/data/gene_ontology_tools.json +113 -0
- tooluniverse/data/gwas_tools.json +1082 -0
- tooluniverse/data/hpa_tools.json +333 -0
- tooluniverse/data/humanbase_tools.json +47 -0
- tooluniverse/data/idmap_tools.json +74 -0
- tooluniverse/data/mcp_client_tools_example.json +113 -0
- tooluniverse/data/mcpautoloadertool_defaults.json +28 -0
- tooluniverse/data/medlineplus_tools.json +141 -0
- tooluniverse/data/monarch_tools.json +1 -1
- tooluniverse/data/openalex_tools.json +36 -0
- tooluniverse/data/opentarget_tools.json +1 -1
- tooluniverse/data/output_summarization_tools.json +101 -0
- tooluniverse/data/packages/bioinformatics_core_tools.json +1756 -0
- tooluniverse/data/packages/categorized_tools.txt +206 -0
- tooluniverse/data/packages/cheminformatics_tools.json +347 -0
- tooluniverse/data/packages/earth_sciences_tools.json +74 -0
- tooluniverse/data/packages/genomics_tools.json +776 -0
- tooluniverse/data/packages/image_processing_tools.json +38 -0
- tooluniverse/data/packages/machine_learning_tools.json +789 -0
- tooluniverse/data/packages/neuroscience_tools.json +62 -0
- tooluniverse/data/packages/original_tools.txt +0 -0
- tooluniverse/data/packages/physics_astronomy_tools.json +62 -0
- tooluniverse/data/packages/scientific_computing_tools.json +560 -0
- tooluniverse/data/packages/single_cell_tools.json +453 -0
- tooluniverse/data/packages/structural_biology_tools.json +396 -0
- tooluniverse/data/packages/visualization_tools.json +399 -0
- tooluniverse/data/pubchem_tools.json +215 -0
- tooluniverse/data/pubtator_tools.json +68 -0
- tooluniverse/data/rcsb_pdb_tools.json +1332 -0
- tooluniverse/data/reactome_tools.json +19 -0
- tooluniverse/data/semantic_scholar_tools.json +26 -0
- tooluniverse/data/special_tools.json +2 -25
- tooluniverse/data/tool_composition_tools.json +88 -0
- tooluniverse/data/toolfinderkeyword_defaults.json +34 -0
- tooluniverse/data/txagent_client_tools.json +9 -0
- tooluniverse/data/uniprot_tools.json +211 -0
- tooluniverse/data/url_fetch_tools.json +94 -0
- tooluniverse/data/uspto_downloader_tools.json +9 -0
- tooluniverse/data/uspto_tools.json +811 -0
- tooluniverse/data/xml_tools.json +3275 -0
- tooluniverse/dataset_tool.py +296 -0
- tooluniverse/default_config.py +165 -0
- tooluniverse/efo_tool.py +42 -0
- tooluniverse/embedding_database.py +630 -0
- tooluniverse/embedding_sync.py +396 -0
- tooluniverse/enrichr_tool.py +266 -0
- tooluniverse/europe_pmc_tool.py +52 -0
- tooluniverse/execute_function.py +1775 -95
- tooluniverse/extended_hooks.py +444 -0
- tooluniverse/gene_ontology_tool.py +194 -0
- tooluniverse/graphql_tool.py +158 -36
- tooluniverse/gwas_tool.py +358 -0
- tooluniverse/hpa_tool.py +1645 -0
- tooluniverse/humanbase_tool.py +389 -0
- tooluniverse/logging_config.py +254 -0
- tooluniverse/mcp_client_tool.py +764 -0
- tooluniverse/mcp_integration.py +413 -0
- tooluniverse/mcp_tool_registry.py +925 -0
- tooluniverse/medlineplus_tool.py +337 -0
- tooluniverse/openalex_tool.py +228 -0
- tooluniverse/openfda_adv_tool.py +283 -0
- tooluniverse/openfda_tool.py +393 -160
- tooluniverse/output_hook.py +1122 -0
- tooluniverse/package_tool.py +195 -0
- tooluniverse/pubchem_tool.py +158 -0
- tooluniverse/pubtator_tool.py +168 -0
- tooluniverse/rcsb_pdb_tool.py +38 -0
- tooluniverse/reactome_tool.py +108 -0
- tooluniverse/remote/boltz/boltz_mcp_server.py +50 -0
- tooluniverse/remote/depmap_24q2/depmap_24q2_mcp_tool.py +442 -0
- tooluniverse/remote/expert_feedback/human_expert_mcp_tools.py +2013 -0
- tooluniverse/remote/expert_feedback/simple_test.py +23 -0
- tooluniverse/remote/expert_feedback/start_web_interface.py +188 -0
- tooluniverse/remote/expert_feedback/web_only_interface.py +0 -0
- tooluniverse/remote/immune_compass/compass_tool.py +327 -0
- tooluniverse/remote/pinnacle/pinnacle_tool.py +328 -0
- tooluniverse/remote/transcriptformer/transcriptformer_tool.py +586 -0
- tooluniverse/remote/uspto_downloader/uspto_downloader_mcp_server.py +61 -0
- tooluniverse/remote/uspto_downloader/uspto_downloader_tool.py +120 -0
- tooluniverse/remote_tool.py +99 -0
- tooluniverse/restful_tool.py +53 -30
- tooluniverse/scripts/generate_tool_graph.py +408 -0
- tooluniverse/scripts/visualize_tool_graph.py +829 -0
- tooluniverse/semantic_scholar_tool.py +62 -0
- tooluniverse/smcp.py +2452 -0
- tooluniverse/smcp_server.py +975 -0
- tooluniverse/test/mcp_server_test.py +0 -0
- tooluniverse/test/test_admetai_tool.py +370 -0
- tooluniverse/test/test_agentic_tool.py +129 -0
- tooluniverse/test/test_alphafold_tool.py +71 -0
- tooluniverse/test/test_chem_tool.py +37 -0
- tooluniverse/test/test_compose_lieraturereview.py +63 -0
- tooluniverse/test/test_compose_tool.py +448 -0
- tooluniverse/test/test_dailymed.py +69 -0
- tooluniverse/test/test_dataset_tool.py +200 -0
- tooluniverse/test/test_disease_target_score.py +56 -0
- tooluniverse/test/test_drugbank_filter_examples.py +179 -0
- tooluniverse/test/test_efo.py +31 -0
- tooluniverse/test/test_enrichr_tool.py +21 -0
- tooluniverse/test/test_europe_pmc_tool.py +20 -0
- tooluniverse/test/test_fda_adv.py +95 -0
- tooluniverse/test/test_fda_drug_labeling.py +91 -0
- tooluniverse/test/test_gene_ontology_tools.py +66 -0
- tooluniverse/test/test_gwas_tool.py +139 -0
- tooluniverse/test/test_hpa.py +625 -0
- tooluniverse/test/test_humanbase_tool.py +20 -0
- tooluniverse/test/test_idmap_tools.py +61 -0
- tooluniverse/test/test_mcp_server.py +211 -0
- tooluniverse/test/test_mcp_tool.py +247 -0
- tooluniverse/test/test_medlineplus.py +220 -0
- tooluniverse/test/test_openalex_tool.py +32 -0
- tooluniverse/test/test_opentargets.py +28 -0
- tooluniverse/test/test_pubchem_tool.py +116 -0
- tooluniverse/test/test_pubtator_tool.py +37 -0
- tooluniverse/test/test_rcsb_pdb_tool.py +86 -0
- tooluniverse/test/test_reactome.py +54 -0
- tooluniverse/test/test_semantic_scholar_tool.py +24 -0
- tooluniverse/test/test_software_tools.py +147 -0
- tooluniverse/test/test_tool_description_optimizer.py +49 -0
- tooluniverse/test/test_tool_finder.py +26 -0
- tooluniverse/test/test_tool_finder_llm.py +252 -0
- tooluniverse/test/test_tools_find.py +195 -0
- tooluniverse/test/test_uniprot_tools.py +74 -0
- tooluniverse/test/test_uspto_tool.py +72 -0
- tooluniverse/test/test_xml_tool.py +113 -0
- tooluniverse/tool_finder_embedding.py +267 -0
- tooluniverse/tool_finder_keyword.py +693 -0
- tooluniverse/tool_finder_llm.py +699 -0
- tooluniverse/tool_graph_web_ui.py +955 -0
- tooluniverse/tool_registry.py +416 -0
- tooluniverse/uniprot_tool.py +155 -0
- tooluniverse/url_tool.py +253 -0
- tooluniverse/uspto_tool.py +240 -0
- tooluniverse/utils.py +369 -41
- tooluniverse/xml_tool.py +369 -0
- tooluniverse-1.0.1.dist-info/METADATA +387 -0
- tooluniverse-1.0.1.dist-info/RECORD +182 -0
- tooluniverse-1.0.1.dist-info/entry_points.txt +9 -0
- tooluniverse/generate_mcp_tools.py +0 -113
- tooluniverse/mcp_server.py +0 -3340
- tooluniverse-0.2.0.dist-info/METADATA +0 -139
- tooluniverse-0.2.0.dist-info/RECORD +0 -21
- tooluniverse-0.2.0.dist-info/entry_points.txt +0 -4
- {tooluniverse-0.2.0.dist-info → tooluniverse-1.0.1.dist-info}/WHEEL +0 -0
- {tooluniverse-0.2.0.dist-info → tooluniverse-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {tooluniverse-0.2.0.dist-info → tooluniverse-1.0.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,563 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
import time
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import Any, Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
import openai
|
|
10
|
+
from google import genai
|
|
11
|
+
|
|
12
|
+
from .base_tool import BaseTool
|
|
13
|
+
from .tool_registry import register_tool
|
|
14
|
+
from .logging_config import get_logger
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@register_tool("AgenticTool")
|
|
18
|
+
class AgenticTool(BaseTool):
|
|
19
|
+
"""Generic wrapper around LLM prompting supporting JSON-defined configs with prompts and input arguments."""
|
|
20
|
+
|
|
21
|
+
def __init__(self, tool_config: Dict[str, Any]):
|
|
22
|
+
super().__init__(tool_config)
|
|
23
|
+
self.logger = get_logger("AgenticTool") # Initialize logger
|
|
24
|
+
self.name: str = tool_config.get("name", "") # Add name attribute
|
|
25
|
+
self._prompt_template: str = tool_config.get("prompt", "")
|
|
26
|
+
self._input_arguments: List[str] = tool_config.get("input_arguments", [])
|
|
27
|
+
|
|
28
|
+
# Extract required arguments from parameter schema
|
|
29
|
+
parameter_info = tool_config.get("parameter", {})
|
|
30
|
+
self._required_arguments: List[str] = parameter_info.get("required", [])
|
|
31
|
+
self._argument_defaults: Dict[str, str] = {}
|
|
32
|
+
|
|
33
|
+
# Set up default values for optional arguments
|
|
34
|
+
properties = parameter_info.get("properties", {})
|
|
35
|
+
for arg in self._input_arguments:
|
|
36
|
+
if arg not in self._required_arguments:
|
|
37
|
+
prop_info = properties.get(arg, {})
|
|
38
|
+
|
|
39
|
+
# First check if there's an explicit "default" field
|
|
40
|
+
if "default" in prop_info:
|
|
41
|
+
self._argument_defaults[arg] = prop_info["default"]
|
|
42
|
+
|
|
43
|
+
# Get configuration from nested 'configs' dict or fallback to top-level
|
|
44
|
+
configs = tool_config.get("configs", {})
|
|
45
|
+
|
|
46
|
+
# Helper function to get config values with fallback
|
|
47
|
+
def get_config(key: str, default: Any) -> Any:
|
|
48
|
+
return configs.get(key, tool_config.get(key, default))
|
|
49
|
+
|
|
50
|
+
# LLM configuration
|
|
51
|
+
self._api_type: str = get_config("api_type", "CHATGPT")
|
|
52
|
+
self._model_id: str = get_config("model_id", "o1-mini")
|
|
53
|
+
self._temperature: float = get_config("temperature", 0.1)
|
|
54
|
+
self._max_new_tokens: int = get_config("max_new_tokens", 2048)
|
|
55
|
+
self._return_json: bool = get_config("return_json", False)
|
|
56
|
+
self._max_retries: int = get_config("max_retries", 5)
|
|
57
|
+
self._retry_delay: int = get_config("retry_delay", 5)
|
|
58
|
+
self.return_metadata: bool = get_config("return_metadata", True)
|
|
59
|
+
|
|
60
|
+
# Validation
|
|
61
|
+
if not self._prompt_template:
|
|
62
|
+
raise ValueError("AgenticTool requires a 'prompt' in the configuration.")
|
|
63
|
+
if not self._input_arguments:
|
|
64
|
+
raise ValueError(
|
|
65
|
+
"AgenticTool requires 'input_arguments' in the configuration."
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# Validate temperature range
|
|
69
|
+
if not 0 <= self._temperature <= 2:
|
|
70
|
+
self.logger.warning(
|
|
71
|
+
f"Temperature {self._temperature} is outside recommended range [0, 2]"
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
# Validate model compatibility
|
|
75
|
+
self._validate_model_config()
|
|
76
|
+
|
|
77
|
+
# Initialize the LLM model
|
|
78
|
+
try:
|
|
79
|
+
self._model, self._tokenizer = self._init_llm(
|
|
80
|
+
api_type=self._api_type, model_id=self._model_id
|
|
81
|
+
)
|
|
82
|
+
self.logger.debug(
|
|
83
|
+
f"Successfully initialized {self._api_type} model: {self._model_id}"
|
|
84
|
+
)
|
|
85
|
+
except Exception as e:
|
|
86
|
+
self.logger.error(f"Failed to initialize LLM model: {str(e)}")
|
|
87
|
+
raise
|
|
88
|
+
|
|
89
|
+
# ------------------------------------------------------------------ LLM utilities -----------
|
|
90
|
+
def _validate_model_config(self):
|
|
91
|
+
"""Validate model configuration parameters."""
|
|
92
|
+
supported_api_types = ["CHATGPT", "GEMINI"]
|
|
93
|
+
if self._api_type not in supported_api_types:
|
|
94
|
+
raise ValueError(
|
|
95
|
+
f"Unsupported API type: {self._api_type}. Supported types: {supported_api_types}"
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
# Validate model-specific configurations
|
|
99
|
+
# if self._api_type == "CHATGPT":
|
|
100
|
+
# supported_models = ["gpt-4o", "o1-mini", "o3-mini"]
|
|
101
|
+
# if self._model_id not in supported_models:
|
|
102
|
+
# self.logger.warning(f"Model {self._model_id} may not be supported. Supported models: {supported_models}")
|
|
103
|
+
|
|
104
|
+
# Validate token limits
|
|
105
|
+
if self._max_new_tokens <= 0:
|
|
106
|
+
raise ValueError("max_new_tokens must be positive")
|
|
107
|
+
|
|
108
|
+
if self._max_new_tokens > 8192: # Conservative limit
|
|
109
|
+
self.logger.warning(
|
|
110
|
+
f"max_new_tokens {self._max_new_tokens} is very high and may cause API issues"
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
def _init_llm(self, api_type: str, model_id: str):
|
|
114
|
+
"""Initialize the LLM model and tokenizer based on API type."""
|
|
115
|
+
if api_type == "CHATGPT":
|
|
116
|
+
if "gpt-4o" in model_id or model_id is None:
|
|
117
|
+
api_key = os.getenv("AZURE_OPENAI_API_KEY")
|
|
118
|
+
api_version = "2024-12-01-preview"
|
|
119
|
+
elif (
|
|
120
|
+
"o1-mini" in model_id or "o3-mini" in model_id or "o4-mini" in model_id
|
|
121
|
+
):
|
|
122
|
+
api_key = os.getenv("AZURE_OPENAI_API_KEY")
|
|
123
|
+
api_version = "2025-03-01-preview"
|
|
124
|
+
else:
|
|
125
|
+
self.logger.error(
|
|
126
|
+
f"Invalid model_id. Please use 'gpt-4o', 'o1-mini', or 'o3-mini'. Got: {model_id}"
|
|
127
|
+
)
|
|
128
|
+
raise ValueError(f"Unsupported model_id: {model_id}")
|
|
129
|
+
|
|
130
|
+
if not api_key:
|
|
131
|
+
raise ValueError(
|
|
132
|
+
"API key not found in environment. Please set the appropriate environment variable."
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
azure_endpoint = os.getenv(
|
|
136
|
+
"AZURE_OPENAI_ENDPOINT", "https://azure-ai.hms.edu"
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
from openai import AzureOpenAI
|
|
140
|
+
|
|
141
|
+
self.logger.debug(
|
|
142
|
+
"Initializing AzureOpenAI client with endpoint:", azure_endpoint
|
|
143
|
+
)
|
|
144
|
+
self.logger.debug("Using API version:", api_version)
|
|
145
|
+
model_client = AzureOpenAI(
|
|
146
|
+
azure_endpoint=azure_endpoint,
|
|
147
|
+
api_key=api_key,
|
|
148
|
+
api_version=api_version,
|
|
149
|
+
)
|
|
150
|
+
model = {
|
|
151
|
+
"model": model_client,
|
|
152
|
+
"model_name": model_id,
|
|
153
|
+
"api_version": api_version,
|
|
154
|
+
}
|
|
155
|
+
tokenizer = None
|
|
156
|
+
elif api_type == "GEMINI":
|
|
157
|
+
api_key = os.getenv("GEMINI_API_KEY")
|
|
158
|
+
if not api_key:
|
|
159
|
+
raise ValueError("GEMINI_API_KEY not found in environment variables")
|
|
160
|
+
|
|
161
|
+
model = genai.Client(api_key=api_key)
|
|
162
|
+
tokenizer = None
|
|
163
|
+
else:
|
|
164
|
+
raise ValueError(f"Unsupported API type: {api_type}")
|
|
165
|
+
|
|
166
|
+
return model, tokenizer
|
|
167
|
+
|
|
168
|
+
def _chatgpt_infer(
|
|
169
|
+
self,
|
|
170
|
+
model: Dict[str, Any],
|
|
171
|
+
messages: List[Dict[str, str]],
|
|
172
|
+
temperature: float = 0.1,
|
|
173
|
+
max_new_tokens: int = 2048,
|
|
174
|
+
return_json: bool = False,
|
|
175
|
+
max_retries: int = 5,
|
|
176
|
+
retry_delay: int = 5,
|
|
177
|
+
custom_format=None,
|
|
178
|
+
) -> Optional[str]:
|
|
179
|
+
"""Inference function for ChatGPT models including o1-mini and o3-mini."""
|
|
180
|
+
model_client = model["model"]
|
|
181
|
+
model_name = model["model_name"]
|
|
182
|
+
|
|
183
|
+
retries = 0
|
|
184
|
+
import traceback
|
|
185
|
+
|
|
186
|
+
if custom_format is not None:
|
|
187
|
+
response_format = custom_format
|
|
188
|
+
call_function = model_client.chat.completions.parse
|
|
189
|
+
elif return_json:
|
|
190
|
+
response_format = {"type": "json_object"}
|
|
191
|
+
call_function = model_client.chat.completions.create
|
|
192
|
+
else:
|
|
193
|
+
response_format = None
|
|
194
|
+
call_function = model_client.chat.completions.create
|
|
195
|
+
while retries < max_retries:
|
|
196
|
+
try:
|
|
197
|
+
if "gpt-4o" in model_name:
|
|
198
|
+
responses = call_function(
|
|
199
|
+
model=model_name,
|
|
200
|
+
messages=messages,
|
|
201
|
+
temperature=temperature,
|
|
202
|
+
max_tokens=max_new_tokens,
|
|
203
|
+
response_format=response_format,
|
|
204
|
+
)
|
|
205
|
+
elif (
|
|
206
|
+
"o1-mini" in model_name
|
|
207
|
+
or "o3-mini" in model_name
|
|
208
|
+
or "o4-mini" in model_name
|
|
209
|
+
):
|
|
210
|
+
responses = call_function(
|
|
211
|
+
model=model_name,
|
|
212
|
+
messages=messages,
|
|
213
|
+
max_completion_tokens=max_new_tokens,
|
|
214
|
+
response_format=response_format,
|
|
215
|
+
)
|
|
216
|
+
if custom_format is not None:
|
|
217
|
+
response = responses.choices[0].message.parsed.model_dump()
|
|
218
|
+
else:
|
|
219
|
+
response = responses.choices[0].message.content
|
|
220
|
+
# print("\033[92m" + response + "\033[0m")
|
|
221
|
+
# usage = responses.usage
|
|
222
|
+
# print("\033[95m" + str(usage) + "\033[0m")
|
|
223
|
+
return response
|
|
224
|
+
except openai.RateLimitError:
|
|
225
|
+
self.logger.warning(
|
|
226
|
+
f"Rate limit exceeded. Retrying in {retry_delay} seconds..."
|
|
227
|
+
)
|
|
228
|
+
retries += 1
|
|
229
|
+
time.sleep(retry_delay * retries)
|
|
230
|
+
except Exception as e:
|
|
231
|
+
self.logger.error(f"An error occurred: {e}")
|
|
232
|
+
traceback.print_exc()
|
|
233
|
+
break
|
|
234
|
+
self.logger.error("Max retries exceeded. Unable to complete the request.")
|
|
235
|
+
return None
|
|
236
|
+
|
|
237
|
+
def _gemini_infer(
|
|
238
|
+
self,
|
|
239
|
+
model: Any,
|
|
240
|
+
messages: List[Dict[str, str]],
|
|
241
|
+
temperature: float = 0.1,
|
|
242
|
+
max_new_tokens: int = 2048,
|
|
243
|
+
return_json: bool = False,
|
|
244
|
+
max_retries: int = 5,
|
|
245
|
+
retry_delay: int = 5,
|
|
246
|
+
model_name: str = "gemini-2.0-flash",
|
|
247
|
+
) -> Optional[str]:
|
|
248
|
+
"""Inference function for Gemini models."""
|
|
249
|
+
retries = 0
|
|
250
|
+
contents = ""
|
|
251
|
+
for message in messages:
|
|
252
|
+
if message["role"] == "user" or message["role"] == "system":
|
|
253
|
+
contents += f"{message['content']}\n"
|
|
254
|
+
elif message["role"] == "assistant":
|
|
255
|
+
raise ValueError(
|
|
256
|
+
"Gemini model does not support assistant role in messages for now in the code."
|
|
257
|
+
)
|
|
258
|
+
else:
|
|
259
|
+
raise ValueError(
|
|
260
|
+
"Invalid role in messages. Only 'user' and 'system' roles are supported."
|
|
261
|
+
)
|
|
262
|
+
|
|
263
|
+
if return_json:
|
|
264
|
+
raise ValueError(
|
|
265
|
+
"Gemini model does not support JSON format for now in the code."
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
while retries < max_retries:
|
|
269
|
+
try:
|
|
270
|
+
response = model.models.generate_content(
|
|
271
|
+
model=model_name,
|
|
272
|
+
contents=contents,
|
|
273
|
+
config=genai.types.GenerateContentConfig(
|
|
274
|
+
max_output_tokens=max_new_tokens,
|
|
275
|
+
temperature=temperature,
|
|
276
|
+
),
|
|
277
|
+
)
|
|
278
|
+
return response.text
|
|
279
|
+
except openai.RateLimitError:
|
|
280
|
+
self.logger.warning(
|
|
281
|
+
f"Rate limit exceeded. Retrying in {retry_delay} seconds..."
|
|
282
|
+
)
|
|
283
|
+
retries += 1
|
|
284
|
+
time.sleep(retry_delay * retries)
|
|
285
|
+
except Exception as e:
|
|
286
|
+
self.logger.error(f"An error occurred: {e}")
|
|
287
|
+
break
|
|
288
|
+
return None
|
|
289
|
+
|
|
290
|
+
# ------------------------------------------------------------------ public API --------------
|
|
291
|
+
def run(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|
292
|
+
"""Execute the tool by formatting the prompt with input arguments and querying the LLM."""
|
|
293
|
+
start_time = datetime.now()
|
|
294
|
+
|
|
295
|
+
try:
|
|
296
|
+
# Validate that all required input arguments are provided
|
|
297
|
+
missing_required_args = [
|
|
298
|
+
arg for arg in self._required_arguments if arg not in arguments
|
|
299
|
+
]
|
|
300
|
+
if missing_required_args:
|
|
301
|
+
raise ValueError(
|
|
302
|
+
f"Missing required input arguments: {missing_required_args}"
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
# Add default values for optional arguments that are missing
|
|
306
|
+
for arg in self._input_arguments:
|
|
307
|
+
if arg not in arguments:
|
|
308
|
+
if arg in self._argument_defaults:
|
|
309
|
+
arguments[arg] = self._argument_defaults[arg]
|
|
310
|
+
else:
|
|
311
|
+
arguments[arg] = "" # Default to empty string for optional args
|
|
312
|
+
|
|
313
|
+
# Validate argument types and content
|
|
314
|
+
self._validate_arguments(arguments)
|
|
315
|
+
|
|
316
|
+
# Format the prompt template with the provided arguments
|
|
317
|
+
formatted_prompt = self._format_prompt(arguments)
|
|
318
|
+
|
|
319
|
+
# Prepare messages for the LLM
|
|
320
|
+
messages = [{"role": "user", "content": formatted_prompt}]
|
|
321
|
+
custom_format = arguments.get("response_format", None)
|
|
322
|
+
# Call the appropriate LLM function based on API type
|
|
323
|
+
response = self._call_llm(messages, custom_format=custom_format)
|
|
324
|
+
|
|
325
|
+
end_time = datetime.now()
|
|
326
|
+
execution_time = (end_time - start_time).total_seconds()
|
|
327
|
+
|
|
328
|
+
if self.return_metadata:
|
|
329
|
+
return {
|
|
330
|
+
"success": True,
|
|
331
|
+
"result": response,
|
|
332
|
+
"metadata": {
|
|
333
|
+
"prompt_used": (
|
|
334
|
+
formatted_prompt
|
|
335
|
+
if len(formatted_prompt) < 1000
|
|
336
|
+
else f"{formatted_prompt[:1000]}..."
|
|
337
|
+
),
|
|
338
|
+
"input_arguments": {
|
|
339
|
+
arg: arguments.get(arg) for arg in self._input_arguments
|
|
340
|
+
},
|
|
341
|
+
"model_info": {
|
|
342
|
+
"api_type": self._api_type,
|
|
343
|
+
"model_id": self._model_id,
|
|
344
|
+
"temperature": self._temperature,
|
|
345
|
+
"max_new_tokens": self._max_new_tokens,
|
|
346
|
+
},
|
|
347
|
+
"execution_time_seconds": execution_time,
|
|
348
|
+
"timestamp": start_time.isoformat(),
|
|
349
|
+
},
|
|
350
|
+
}
|
|
351
|
+
else:
|
|
352
|
+
return response
|
|
353
|
+
|
|
354
|
+
except Exception as e:
|
|
355
|
+
end_time = datetime.now()
|
|
356
|
+
execution_time = (end_time - start_time).total_seconds()
|
|
357
|
+
|
|
358
|
+
self.logger.error(f"Error executing {self.name}: {str(e)}")
|
|
359
|
+
|
|
360
|
+
if self.return_metadata:
|
|
361
|
+
return {
|
|
362
|
+
"success": False,
|
|
363
|
+
"error": str(e),
|
|
364
|
+
"error_type": type(e).__name__,
|
|
365
|
+
"metadata": {
|
|
366
|
+
"prompt_used": (
|
|
367
|
+
formatted_prompt
|
|
368
|
+
if "formatted_prompt" in locals()
|
|
369
|
+
else "Failed to format prompt"
|
|
370
|
+
),
|
|
371
|
+
"input_arguments": {
|
|
372
|
+
arg: arguments.get(arg) for arg in self._input_arguments
|
|
373
|
+
},
|
|
374
|
+
"model_info": {
|
|
375
|
+
"api_type": self._api_type,
|
|
376
|
+
"model_id": self._model_id,
|
|
377
|
+
},
|
|
378
|
+
"execution_time_seconds": execution_time,
|
|
379
|
+
"timestamp": start_time.isoformat(),
|
|
380
|
+
},
|
|
381
|
+
}
|
|
382
|
+
else:
|
|
383
|
+
return "error: " + str(e) + " error_type: " + type(e).__name__
|
|
384
|
+
|
|
385
|
+
# ------------------------------------------------------------------ helpers -----------------
|
|
386
|
+
def _validate_arguments(self, arguments: Dict[str, Any]):
|
|
387
|
+
"""Validate input arguments for common issues."""
|
|
388
|
+
for arg_name, value in arguments.items():
|
|
389
|
+
if arg_name in self._input_arguments:
|
|
390
|
+
# Check for empty strings only for required arguments
|
|
391
|
+
if isinstance(value, str) and not value.strip():
|
|
392
|
+
if arg_name in self._required_arguments:
|
|
393
|
+
raise ValueError(
|
|
394
|
+
f"Required argument '{arg_name}' cannot be empty"
|
|
395
|
+
)
|
|
396
|
+
# Optional arguments can be empty, so we skip the check
|
|
397
|
+
|
|
398
|
+
# Check for extremely long inputs that might cause issues - silent validation
|
|
399
|
+
if (
|
|
400
|
+
isinstance(value, str) and len(value) > 100000
|
|
401
|
+
): # 100k character limit
|
|
402
|
+
pass # Could potentially cause API issues but no need to spam output
|
|
403
|
+
|
|
404
|
+
def _format_prompt(self, arguments: Dict[str, Any]) -> str:
|
|
405
|
+
"""Format the prompt template with the provided arguments."""
|
|
406
|
+
prompt = self._prompt_template
|
|
407
|
+
|
|
408
|
+
# Track which placeholders we actually replace
|
|
409
|
+
replaced_placeholders = set()
|
|
410
|
+
|
|
411
|
+
# Replace placeholders in the format {argument_name} with actual values
|
|
412
|
+
for arg_name in self._input_arguments:
|
|
413
|
+
placeholder = f"{{{arg_name}}}"
|
|
414
|
+
value = arguments.get(arg_name, "")
|
|
415
|
+
|
|
416
|
+
if placeholder in prompt:
|
|
417
|
+
replaced_placeholders.add(arg_name)
|
|
418
|
+
# Handle special characters and formatting
|
|
419
|
+
if isinstance(value, str):
|
|
420
|
+
# Simple replacement without complex escaping that was causing issues
|
|
421
|
+
prompt = prompt.replace(placeholder, str(value))
|
|
422
|
+
else:
|
|
423
|
+
prompt = prompt.replace(placeholder, str(value))
|
|
424
|
+
|
|
425
|
+
# Check for unreplaced expected placeholders (only check our input arguments)
|
|
426
|
+
# _unreplaced_expected = [
|
|
427
|
+
# arg for arg in self._input_arguments if arg not in replaced_placeholders
|
|
428
|
+
# ]
|
|
429
|
+
|
|
430
|
+
# Silent handling - no debug output needed for template patterns in JSON content
|
|
431
|
+
|
|
432
|
+
return prompt
|
|
433
|
+
|
|
434
|
+
def _call_llm(self, messages: List[Dict[str, str]], custom_format=None) -> str:
|
|
435
|
+
"""Make the actual LLM API call using the appropriate function."""
|
|
436
|
+
if self._api_type == "CHATGPT":
|
|
437
|
+
response = self._chatgpt_infer(
|
|
438
|
+
model=self._model,
|
|
439
|
+
messages=messages,
|
|
440
|
+
temperature=self._temperature,
|
|
441
|
+
max_new_tokens=self._max_new_tokens,
|
|
442
|
+
return_json=self._return_json,
|
|
443
|
+
max_retries=self._max_retries,
|
|
444
|
+
retry_delay=self._retry_delay,
|
|
445
|
+
custom_format=custom_format,
|
|
446
|
+
)
|
|
447
|
+
if response is None:
|
|
448
|
+
raise Exception("LLM API call failed after maximum retries")
|
|
449
|
+
return response
|
|
450
|
+
|
|
451
|
+
elif self._api_type == "GEMINI":
|
|
452
|
+
response = self._gemini_infer(
|
|
453
|
+
model=self._model,
|
|
454
|
+
messages=messages,
|
|
455
|
+
temperature=self._temperature,
|
|
456
|
+
max_new_tokens=self._max_new_tokens,
|
|
457
|
+
return_json=self._return_json,
|
|
458
|
+
max_retries=self._max_retries,
|
|
459
|
+
retry_delay=self._retry_delay,
|
|
460
|
+
)
|
|
461
|
+
if response is None:
|
|
462
|
+
raise Exception("Gemini API call failed after maximum retries")
|
|
463
|
+
return response
|
|
464
|
+
|
|
465
|
+
else:
|
|
466
|
+
raise ValueError(f"Unsupported API type: {self._api_type}")
|
|
467
|
+
|
|
468
|
+
def get_prompt_preview(self, arguments: Dict[str, Any]) -> str:
|
|
469
|
+
"""Preview how the prompt will look with the given arguments (useful for debugging)."""
|
|
470
|
+
try:
|
|
471
|
+
# Create a copy to avoid modifying the original arguments
|
|
472
|
+
args_copy = arguments.copy()
|
|
473
|
+
|
|
474
|
+
# Validate that all required input arguments are provided
|
|
475
|
+
missing_required_args = [
|
|
476
|
+
arg for arg in self._required_arguments if arg not in args_copy
|
|
477
|
+
]
|
|
478
|
+
if missing_required_args:
|
|
479
|
+
raise ValueError(
|
|
480
|
+
f"Missing required input arguments: {missing_required_args}"
|
|
481
|
+
)
|
|
482
|
+
|
|
483
|
+
# Add default values for optional arguments that are missing
|
|
484
|
+
for arg in self._input_arguments:
|
|
485
|
+
if arg not in args_copy:
|
|
486
|
+
if arg in self._argument_defaults:
|
|
487
|
+
args_copy[arg] = self._argument_defaults[arg]
|
|
488
|
+
else:
|
|
489
|
+
args_copy[arg] = "" # Default to empty string for optional args
|
|
490
|
+
|
|
491
|
+
return self._format_prompt(args_copy)
|
|
492
|
+
except Exception as e:
|
|
493
|
+
return f"Error formatting prompt: {str(e)}"
|
|
494
|
+
|
|
495
|
+
def get_model_info(self) -> Dict[str, Any]:
|
|
496
|
+
"""Get comprehensive information about the configured model."""
|
|
497
|
+
return {
|
|
498
|
+
"api_type": self._api_type,
|
|
499
|
+
"model_id": self._model_id,
|
|
500
|
+
"temperature": self._temperature,
|
|
501
|
+
"max_new_tokens": self._max_new_tokens,
|
|
502
|
+
"return_json": self._return_json,
|
|
503
|
+
"max_retries": self._max_retries,
|
|
504
|
+
"retry_delay": self._retry_delay,
|
|
505
|
+
}
|
|
506
|
+
|
|
507
|
+
def get_prompt_template(self) -> str:
|
|
508
|
+
"""Get the raw prompt template."""
|
|
509
|
+
return self._prompt_template
|
|
510
|
+
|
|
511
|
+
def get_input_arguments(self) -> List[str]:
|
|
512
|
+
"""Get the list of required input arguments."""
|
|
513
|
+
return self._input_arguments.copy()
|
|
514
|
+
|
|
515
|
+
def validate_configuration(self) -> Dict[str, Any]:
|
|
516
|
+
"""Validate the tool configuration and return validation results."""
|
|
517
|
+
validation_results = {"valid": True, "warnings": [], "errors": []}
|
|
518
|
+
|
|
519
|
+
try:
|
|
520
|
+
self._validate_model_config()
|
|
521
|
+
except ValueError as e:
|
|
522
|
+
validation_results["valid"] = False
|
|
523
|
+
validation_results["errors"].append(str(e))
|
|
524
|
+
|
|
525
|
+
# Check prompt template
|
|
526
|
+
if not self._prompt_template:
|
|
527
|
+
validation_results["valid"] = False
|
|
528
|
+
validation_results["errors"].append("Missing prompt template")
|
|
529
|
+
|
|
530
|
+
# Check for placeholder consistency
|
|
531
|
+
placeholders_in_prompt = set(re.findall(r"\{([^}]+)\}", self._prompt_template))
|
|
532
|
+
required_args = set(self._input_arguments)
|
|
533
|
+
|
|
534
|
+
missing_in_prompt = required_args - placeholders_in_prompt
|
|
535
|
+
extra_in_prompt = placeholders_in_prompt - required_args
|
|
536
|
+
|
|
537
|
+
if missing_in_prompt:
|
|
538
|
+
validation_results["warnings"].append(
|
|
539
|
+
f"Arguments not used in prompt: {missing_in_prompt}"
|
|
540
|
+
)
|
|
541
|
+
|
|
542
|
+
if extra_in_prompt:
|
|
543
|
+
validation_results["warnings"].append(
|
|
544
|
+
f"Placeholders in prompt without corresponding arguments: {extra_in_prompt}"
|
|
545
|
+
)
|
|
546
|
+
|
|
547
|
+
return validation_results
|
|
548
|
+
|
|
549
|
+
def estimate_token_usage(self, arguments: Dict[str, Any]) -> Dict[str, int]:
    """Give a rough token budget for a call made with ``arguments``.

    The prompt is rendered with ``self._format_prompt(arguments)`` and sized
    with the ~4-characters-per-token heuristic for English text. The output
    side is simply the configured ``self._max_new_tokens`` ceiling, so the
    total is an approximation, not a measurement.
    """
    rendered = self._format_prompt(arguments)
    char_count = len(rendered)

    input_tokens = char_count // 4  # heuristic: ~4 chars per token
    output_cap = self._max_new_tokens

    return {
        "estimated_input_tokens": input_tokens,
        "max_output_tokens": output_cap,
        "estimated_total_tokens": input_tokens + output_cap,
        "prompt_length_chars": char_count,
    }
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
import re
import urllib.parse
from typing import Any, Dict, List, Union

import requests

from .base_tool import BaseTool
from .tool_registry import register_tool
|
|
6
|
+
|
|
7
|
+
# Root of the AlphaFold Protein Structure Database REST API; endpoint paths
# from the tool config are appended to this.
ALPHAFOLD_BASE_URL: str = "https://alphafold.ebi.ac.uk/api"
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@register_tool("AlphaFoldRESTTool")
class AlphaFoldRESTTool(BaseTool):
    """
    AlphaFold Protein Structure Database API tool.

    Supports queries by UniProt accession ID.

    Tool config keys read here:
      * ``fields["endpoint"]`` (required): a path template appended to
        ``ALPHAFOLD_BASE_URL``; ``{name}`` placeholders are filled from the
        call arguments.
      * ``fields["return_format"]`` (optional, default ``"JSON"``).
      * ``fields["timeout"]`` (optional, default 30 seconds).
      * ``parameter["properties"]`` / ``parameter["required"]``: argument schema.

    Failures are reported as dicts with an ``"error"`` key rather than raised.
    """

    def __init__(self, tool_config):
        super().__init__(tool_config)
        fields = tool_config.get("fields", {})
        parameter = tool_config.get("parameter", {})

        # ``endpoint`` is mandatory; a missing key raises KeyError here.
        self.endpoint_template: str = fields["endpoint"]
        self.param_schema: Dict[str, Any] = parameter.get("properties", {})
        self.required: List[str] = parameter.get("required", [])
        self.output_format: str = fields.get("return_format", "JSON")
        # Per-request timeout in seconds (previously hard-coded to 30).
        self.timeout: float = fields.get("timeout", 30)

    def _build_url(self, arguments: Dict[str, Any]) -> Union[Dict[str, Any], str]:
        """Fill the endpoint template's ``{placeholder}`` slots from ``arguments``.

        Returns the absolute URL, or an ``{"error": ...}`` dict when a
        placeholder has no non-None value in ``arguments``.
        """
        url_path = self.endpoint_template
        for ph in re.findall(r"\{([^{}]+)\}", url_path):
            value = arguments.get(ph)
            if value is None:
                return {"error": f"Missing required parameter '{ph}'"}
            # Percent-encode (including '/') so caller input cannot alter the
            # path structure or smuggle in a query string / fragment.
            url_path = url_path.replace(
                f"{{{ph}}}", urllib.parse.quote(str(value), safe="")
            )
        return ALPHAFOLD_BASE_URL + url_path

    def run(self, arguments: Dict[str, Any]):
        """Query the AlphaFold API.

        Returns:
            JSON format: ``{"data": <parsed JSON>, "metadata": {...}}``.
            Other formats: ``{"data": <response text>, "metadata": {"endpoint": url}}``.
            On failure: a dict with an ``"error"`` key, plus ``"detail"``,
            ``"raw"`` or ``"uniprot_id"`` where available.
        """
        arguments = arguments or {}

        # Validate required params; an explicit None counts as missing,
        # consistent with the placeholder check in _build_url.
        missing = [k for k in self.required if arguments.get(k) is None]
        if missing:
            return {"error": f"Missing required parameter(s): {', '.join(missing)}"}

        url = self._build_url(arguments)
        if isinstance(url, dict):
            return url

        try:
            resp = requests.get(
                url,
                timeout=self.timeout,
                headers={
                    "Accept": "application/json",
                    "User-Agent": "ToolUniverse/AlphaFold",
                },
            )
        except requests.RequestException as e:
            # Connection errors, timeouts, invalid URLs, etc.
            return {"error": "Request to AlphaFold API failed", "detail": str(e)}

        # Handle HTTP errors cleanly
        if resp.status_code == 404:
            return {
                "error": "No AlphaFold prediction found",
                "uniprot_id": arguments.get("uniprot_id"),
            }
        if resp.status_code != 200:
            return {
                "error": f"AlphaFold API returned {resp.status_code}",
                "detail": resp.text,
            }

        # Fallback if non-JSON format: pass the body through as text.
        if self.output_format.upper() != "JSON":
            return {"data": resp.text, "metadata": {"endpoint": url}}

        try:
            data = resp.json()
        except ValueError as e:
            # requests raises a ValueError subclass for undecodable bodies.
            return {
                "error": "Failed to parse JSON response",
                "raw": resp.text,
                "detail": str(e),
            }

        if not data:
            return {
                "error": "AlphaFold returned an empty response",
                "uniprot_id": arguments.get("uniprot_id"),
            }

        return {
            "data": data,
            "metadata": {
                "count": len(data) if isinstance(data, list) else 1,
                "source": "AlphaFold Protein Structure DB",
                "endpoint": url,
                # Copy so later mutation of the caller's dict doesn't leak in.
                "query": dict(arguments),
            },
        }
|