aiagents4pharma 1.9.0__py3-none-any.whl → 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/__init__.py +9 -6
- aiagents4pharma/configs/config.yaml +2 -1
- aiagents4pharma/configs/talk2biomodels/__init__.py +1 -0
- aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml +9 -3
- aiagents4pharma/configs/talk2biomodels/tools/__init__.py +4 -0
- aiagents4pharma/configs/talk2biomodels/tools/ask_question/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/__init__.py +1 -0
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py +14 -11
- aiagents4pharma/talk2biomodels/api/__init__.py +6 -0
- aiagents4pharma/talk2biomodels/api/kegg.py +83 -0
- aiagents4pharma/talk2biomodels/api/ols.py +72 -0
- aiagents4pharma/talk2biomodels/api/uniprot.py +35 -0
- aiagents4pharma/talk2biomodels/models/basico_model.py +29 -32
- aiagents4pharma/talk2biomodels/models/sys_bio_model.py +9 -6
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +24 -7
- aiagents4pharma/talk2biomodels/tests/test_api.py +57 -0
- aiagents4pharma/talk2biomodels/tests/test_ask_question.py +44 -0
- aiagents4pharma/talk2biomodels/tests/test_basico_model.py +7 -8
- aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +171 -0
- aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +26 -0
- aiagents4pharma/talk2biomodels/tests/test_integration.py +126 -0
- aiagents4pharma/talk2biomodels/tests/test_param_scan.py +68 -0
- aiagents4pharma/talk2biomodels/tests/test_query_article.py +76 -0
- aiagents4pharma/talk2biomodels/tests/test_search_models.py +28 -0
- aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +39 -0
- aiagents4pharma/talk2biomodels/tests/test_steady_state.py +90 -0
- aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +13 -7
- aiagents4pharma/talk2biomodels/tools/__init__.py +4 -0
- aiagents4pharma/talk2biomodels/tools/ask_question.py +59 -25
- aiagents4pharma/talk2biomodels/tools/get_annotation.py +304 -0
- aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +5 -3
- aiagents4pharma/talk2biomodels/tools/load_arguments.py +114 -0
- aiagents4pharma/talk2biomodels/tools/parameter_scan.py +287 -0
- aiagents4pharma/talk2biomodels/tools/query_article.py +59 -0
- aiagents4pharma/talk2biomodels/tools/simulate_model.py +20 -89
- aiagents4pharma/talk2biomodels/tools/steady_state.py +167 -0
- aiagents4pharma/talk2competitors/__init__.py +5 -0
- aiagents4pharma/talk2competitors/agents/__init__.py +6 -0
- aiagents4pharma/talk2competitors/agents/main_agent.py +130 -0
- aiagents4pharma/talk2competitors/agents/s2_agent.py +75 -0
- aiagents4pharma/talk2competitors/config/__init__.py +5 -0
- aiagents4pharma/talk2competitors/config/config.py +110 -0
- aiagents4pharma/talk2competitors/state/__init__.py +5 -0
- aiagents4pharma/talk2competitors/state/state_talk2competitors.py +32 -0
- aiagents4pharma/talk2competitors/tests/__init__.py +3 -0
- aiagents4pharma/talk2competitors/tests/test_langgraph.py +274 -0
- aiagents4pharma/talk2competitors/tools/__init__.py +7 -0
- aiagents4pharma/talk2competitors/tools/s2/__init__.py +8 -0
- aiagents4pharma/talk2competitors/tools/s2/display_results.py +25 -0
- aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py +132 -0
- aiagents4pharma/talk2competitors/tools/s2/search.py +119 -0
- aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py +141 -0
- aiagents4pharma/talk2knowledgegraphs/__init__.py +2 -1
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py +39 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +117 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py +36 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py +123 -0
- {aiagents4pharma-1.9.0.dist-info → aiagents4pharma-1.15.0.dist-info}/METADATA +42 -23
- aiagents4pharma-1.15.0.dist-info/RECORD +102 -0
- aiagents4pharma/talk2biomodels/tests/test_langgraph.py +0 -240
- aiagents4pharma-1.9.0.dist-info/RECORD +0 -62
- {aiagents4pharma-1.9.0.dist-info → aiagents4pharma-1.15.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.9.0.dist-info → aiagents4pharma-1.15.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.9.0.dist-info → aiagents4pharma-1.15.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,132 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
"""
|
4
|
+
multi_paper_rec: Tool for getting recommendations
|
5
|
+
based on multiple papers
|
6
|
+
"""
|
7
|
+
|
8
|
+
import json
|
9
|
+
import logging
|
10
|
+
from typing import Annotated, Any, Dict, List, Optional
|
11
|
+
|
12
|
+
import pandas as pd
|
13
|
+
import requests
|
14
|
+
from langchain_core.messages import ToolMessage
|
15
|
+
from langchain_core.tools import tool
|
16
|
+
from langchain_core.tools.base import InjectedToolCallId
|
17
|
+
from langgraph.types import Command
|
18
|
+
from pydantic import BaseModel, Field
|
19
|
+
|
20
|
+
|
21
|
+
class MultiPaperRecInput(BaseModel):
    """Input schema for multiple paper recommendations tool."""

    # Seed papers whose combined neighbourhood drives the recommendations.
    paper_ids: List[str] = Field(
        description="List of Semantic Scholar Paper IDs to get recommendations for"
    )
    # Validated to the inclusive range 1..500.
    limit: int = Field(
        2,
        ge=1,
        le=500,
        description="Maximum total number of recommendations to return",
    )
    # Optional publication-year filter.
    year: Optional[str] = Field(
        None,
        description=(
            "Year range in format: YYYY for specific year, "
            "YYYY- for papers after year, -YYYY for papers before year, "
            "or YYYY:YYYY for range"
        ),
    )
    # Filled in via InjectedToolCallId rather than chosen by the model.
    tool_call_id: Annotated[str, InjectedToolCallId]

    model_config = {"arbitrary_types_allowed": True}
|
41
|
+
|
42
|
+
|
43
|
+
@tool(args_schema=MultiPaperRecInput)
def get_multi_paper_recommendations(
    paper_ids: List[str],
    tool_call_id: Annotated[str, InjectedToolCallId],
    limit: int = 2,
    year: Optional[str] = None,
) -> Command:
    """
    Get paper recommendations based on multiple papers.

    Args:
        paper_ids (List[str]): The list of paper IDs to base recommendations on.
        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
        limit (int, optional): The maximum number of recommendations to return. Defaults to 2.
        year (str, optional): Year range for papers.
            Supports formats like "2024-", "-2024", "2024:2025". Defaults to None.

    Returns:
        Command: A state update with the recommended papers and a markdown table message.
    """
    logging.info("Starting multi-paper recommendations search.")

    endpoint = "https://api.semanticscholar.org/recommendations/v1/papers"
    headers = {"Content-Type": "application/json"}
    payload = {"positivePaperIds": paper_ids, "negativePaperIds": []}
    params = {
        "limit": min(limit, 500),
        "fields": "paperId,title,abstract,year,authors,citationCount,url",
    }

    # Add year parameter if provided
    if year:
        params["year"] = year

    # Getting recommendations
    response = requests.post(
        endpoint,
        headers=headers,
        params=params,
        data=json.dumps(payload),
        timeout=10,
    )
    logging.info(
        "API Response Status for multi-paper recommendations: %s", response.status_code
    )
    if not response.ok:
        # Make API failures (e.g. rate limiting) visible instead of silently
        # yielding an empty result table.
        logging.warning(
            "Multi-paper recommendations request failed (%s): %s",
            response.status_code,
            response.text[:500],
        )

    data = response.json()
    recommendations = data.get("recommendedPapers", [])

    # Output column label -> Semantic Scholar field name.
    columns = {
        "Title": "title",
        "Abstract": "abstract",
        "Year": "year",
        "Citation Count": "citationCount",
        "URL": "url",
    }
    # Keep entries with an ID (used as the key) and a title. Missing values may
    # arrive as explicit nulls, so treat None like an absent key -> "N/A".
    filtered_papers: Dict[str, Dict[str, Any]] = {
        paper["paperId"]: {
            label: "N/A" if paper.get(field) is None else paper[field]
            for label, field in columns.items()
        }
        for paper in recommendations
        if paper.get("title") and paper.get("paperId")
    }

    # One row per paper, indexed by paper ID.
    df = pd.DataFrame.from_dict(filtered_papers, orient="index")
    logging.info("Created DataFrame with results: %s", df)

    # Human-readable per-paper summaries, used for logging only.
    summaries = [
        f"Paper ID: {rec_id}\n"
        f"Title: {info['Title']}\n"
        f"Abstract: {info['Abstract']}\n"
        f"Year: {info['Year']}\n"
        f"Citations: {info['Citation Count']}\n"
        f"URL: {info['URL']}\n"
        for rec_id, info in filtered_papers.items()
    ]

    # Convert DataFrame to markdown table
    markdown_table = df.to_markdown(tablefmt="grid")
    logging.info("Search results: %s", summaries)

    return Command(
        update={
            # Paper ID -> metadata dict, stored directly in graph state.
            "papers": filtered_papers,
            "messages": [
                ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
            ],
        }
    )
|
@@ -0,0 +1,119 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
"""
|
4
|
+
This tool is used to search for academic papers on Semantic Scholar.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from typing import Annotated, Any, Dict, Optional
|
9
|
+
|
10
|
+
import pandas as pd
|
11
|
+
import requests
|
12
|
+
from langchain_core.messages import ToolMessage
|
13
|
+
from langchain_core.tools import tool
|
14
|
+
from langchain_core.tools.base import InjectedToolCallId
|
15
|
+
from langgraph.types import Command
|
16
|
+
from pydantic import BaseModel, Field
|
17
|
+
|
18
|
+
|
19
|
+
class SearchInput(BaseModel):
    """Input schema for the search papers tool."""

    # Free-text query. The two literals are implicitly concatenated, so the
    # first needs a trailing space to keep the sentences separated.
    query: str = Field(
        description="Search query string to find academic papers. "
        "Be specific and include relevant academic terms."
    )
    # Validated to the inclusive range 1..100.
    limit: int = Field(
        default=2, description="Maximum number of results to return", ge=1, le=100
    )
    # Optional publication-year filter.
    year: Optional[str] = Field(
        default=None,
        description="Year range in format: YYYY for specific year, "
        "YYYY- for papers after year, -YYYY for papers before year, or YYYY:YYYY for range",
    )
    # Filled in via InjectedToolCallId rather than chosen by the model.
    tool_call_id: Annotated[str, InjectedToolCallId]
|
35
|
+
|
36
|
+
|
37
|
+
@tool(args_schema=SearchInput)
def search_tool(
    query: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    limit: int = 2,
    year: Optional[str] = None,
) -> Command:
    """
    Search for academic papers on Semantic Scholar.

    Args:
        query (str): The search query string to find academic papers.
        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
        limit (int, optional): The maximum number of results to return. Defaults to 2.
        year (str, optional): Year range for papers.
            Supports formats like "2024-", "-2024", "2024:2025". Defaults to None.

    Returns:
        Command: A state update with the found papers and a markdown table message.
    """
    logging.info("Starting paper search...")
    endpoint = "https://api.semanticscholar.org/graph/v1/paper/search"
    params = {
        "query": query,
        "limit": min(limit, 100),
        "fields": "paperId,title,abstract,year,authors,citationCount,url",
    }

    # Add year parameter if provided
    if year:
        params["year"] = year

    response = requests.get(endpoint, params=params, timeout=10)
    logging.info("API Response Status for paper search: %s", response.status_code)
    if not response.ok:
        # Make API failures (e.g. rate limiting) visible instead of silently
        # yielding an empty result table.
        logging.warning(
            "Paper search request failed (%s): %s",
            response.status_code,
            response.text[:500],
        )

    data = response.json()
    papers = data.get("data", [])

    # Output column label -> Semantic Scholar field name.
    columns = {
        "Title": "title",
        "Abstract": "abstract",
        "Year": "year",
        "Citation Count": "citationCount",
        "URL": "url",
    }
    # Keep papers with an ID (used as the key), a title and authors. Missing
    # values may arrive as explicit nulls, so treat None like an absent key.
    filtered_papers: Dict[str, Dict[str, Any]] = {
        paper["paperId"]: {
            label: "N/A" if paper.get(field) is None else paper[field]
            for label, field in columns.items()
        }
        for paper in papers
        if paper.get("paperId") and paper.get("title") and paper.get("authors")
    }

    # One row per paper, indexed by paper ID (pd.DataFrame(dict_of_dicts)
    # would instead make every paper a column).
    df = pd.DataFrame.from_dict(filtered_papers, orient="index")

    # Human-readable per-paper summaries, used for logging only.
    summaries = [
        f"Paper ID: {rec_id}\n"
        f"Title: {info['Title']}\n"
        f"Abstract: {info['Abstract']}\n"
        f"Year: {info['Year']}\n"
        f"Citations: {info['Citation Count']}\n"
        f"URL: {info['URL']}\n"
        for rec_id, info in filtered_papers.items()
    ]

    markdown_table = df.to_markdown(tablefmt="grid")
    logging.info("Search results: %s", summaries)

    return Command(
        update={
            # Paper ID -> metadata dict, stored directly in graph state.
            "papers": filtered_papers,
            "messages": [
                ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
            ],
        }
    )
|
@@ -0,0 +1,141 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
"""
|
4
|
+
This tool is used to return recommendations for a single paper.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from typing import Annotated, Any, Dict, Optional
|
9
|
+
|
10
|
+
import pandas as pd
|
11
|
+
import requests
|
12
|
+
from langchain_core.messages import ToolMessage
|
13
|
+
from langchain_core.tools import tool
|
14
|
+
from langchain_core.tools.base import InjectedToolCallId
|
15
|
+
from langgraph.types import Command
|
16
|
+
from pydantic import BaseModel, Field
|
17
|
+
|
18
|
+
# Configure logging
# NOTE(review): basicConfig runs at import time and configures the ROOT logger
# for the whole process; library modules usually leave this to the application.
# Left as-is because this module also emits records through the root
# ``logging`` functions (e.g. ``logging.info``), which would otherwise be
# filtered at the default WARNING level.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger; propagates to the root handler configured above.
logger = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
|
23
|
+
class SinglePaperRecInput(BaseModel):
    """Input schema for single paper recommendation tool."""

    # Seed paper; the description tells the model the expected ID shape.
    paper_id: str = Field(
        description="Semantic Scholar Paper ID to get recommendations for (40-character string)"
    )
    # Validated to the inclusive range 1..500.
    limit: int = Field(
        2,
        le=500,
        ge=1,
        description="Maximum number of recommendations to return",
    )
    # Optional publication-year filter.
    year: Optional[str] = Field(
        None,
        description=(
            "Year range in format: YYYY for specific year, "
            "YYYY- for papers after year, -YYYY for papers before year, "
            "or YYYY:YYYY for range"
        ),
    )
    # Filled in via InjectedToolCallId rather than chosen by the model.
    tool_call_id: Annotated[str, InjectedToolCallId]
    model_config = {"arbitrary_types_allowed": True}
|
42
|
+
|
43
|
+
|
44
|
+
@tool(args_schema=SinglePaperRecInput)
def get_single_paper_recommendations(
    paper_id: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    limit: int = 2,
    year: Optional[str] = None,
) -> Command:
    """
    Get paper recommendations based on a single paper.

    Args:
        paper_id (str): The Semantic Scholar Paper ID to get recommendations for.
        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
        limit (int, optional): The maximum number of recommendations to return. Defaults to 2.
        year (str, optional): Year range for papers.
            Supports formats like "2024-", "-2024", "2024:2025". Defaults to None.

    Returns:
        Command: A state update with the recommended papers and a markdown table message.
    """
    logger.info("Starting single paper recommendations search.")

    endpoint = (
        f"https://api.semanticscholar.org/recommendations/v1/papers/forpaper/{paper_id}"
    )
    params = {
        "limit": min(limit, 500),  # Max 500 per API docs
        "fields": "paperId,title,abstract,year,authors,citationCount,url",
        "from": "all-cs",  # Using all-cs pool as specified in docs
    }

    # Add year parameter if provided
    if year:
        params["year"] = year

    # Exactly one request per tool call; its response is parsed once below.
    response = requests.get(endpoint, params=params, timeout=10)
    logger.info(
        "API Response Status for recommendations of paper %s: %s",
        paper_id,
        response.status_code,
    )
    logger.info("Request params: %s", params)
    if not response.ok:
        # Make API failures (e.g. rate limiting) visible instead of silently
        # yielding an empty result table.
        logger.warning(
            "Recommendations request for paper %s failed (%s): %s",
            paper_id,
            response.status_code,
            response.text[:500],
        )

    data = response.json()
    recommendations = data.get("recommendedPapers", [])

    # Output column label -> Semantic Scholar field name.
    columns = {
        "Title": "title",
        "Abstract": "abstract",
        "Year": "year",
        "Citation Count": "citationCount",
        "URL": "url",
    }
    # Keep recommendations with an ID (used as the key), a title and authors.
    # Missing values may arrive as explicit nulls, so treat None like an absent key.
    filtered_papers: Dict[str, Dict[str, Any]] = {
        rec["paperId"]: {
            label: "N/A" if rec.get(field) is None else rec[field]
            for label, field in columns.items()
        }
        for rec in recommendations
        if rec.get("paperId") and rec.get("title") and rec.get("authors")
    }

    # One row per paper, indexed by paper ID (pd.DataFrame(dict_of_dicts)
    # would instead make every paper a column).
    df = pd.DataFrame.from_dict(filtered_papers, orient="index")

    # Human-readable per-paper summaries, used for logging only. The loop
    # variable is named rec_id to avoid shadowing the ``paper_id`` argument.
    summaries = [
        f"Paper ID: {rec_id}\n"
        f"Title: {info['Title']}\n"
        f"Abstract: {info['Abstract']}\n"
        f"Year: {info['Year']}\n"
        f"Citations: {info['Citation Count']}\n"
        f"URL: {info['URL']}\n"
        for rec_id, info in filtered_papers.items()
    ]

    # Convert DataFrame to markdown table
    markdown_table = df.to_markdown(tablefmt="grid")
    logger.info("Search results: %s", summaries)

    return Command(
        update={
            # Paper ID -> metadata dict, stored directly in graph state.
            "papers": filtered_papers,
            "messages": [
                ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
            ],
        }
    )
|
@@ -0,0 +1,39 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for utils/enrichments/enrichments.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
from ..utils.enrichments.enrichments import Enrichments
|
6
|
+
|
7
|
+
class TestEnrichments(Enrichments):
    """Test implementation of the Enrichments interface for testing purposes."""

    # Despite the ``Test`` prefix this is a helper, not a test class; tell
    # pytest not to try collecting it.
    __test__ = False

    def enrich_documents(self, texts: list[str]) -> list[str]:
        """Return one templated description string per input text."""
        return [
            f"Additional text description of {text} as the input." for text in texts
        ]

    def enrich_documents_with_rag(self, texts: list[str], docs: list[str]) -> list[str]:
        """Ignore ``docs`` and delegate to ``enrich_documents``."""
        # Currently we don't have a RAG model to test this method.
        # Thus, we will just call the enrich_documents method instead.
        return self.enrich_documents(texts)
|
19
|
+
|
20
|
+
def test_enrich_documents():
    """Each input text should map to its templated description, in order."""
    expected = [
        "Additional text description of text1 as the input.",
        "Additional text description of text2 as the input.",
    ]
    assert TestEnrichments().enrich_documents(["text1", "text2"]) == expected
|
29
|
+
|
30
|
+
def test_enrich_documents_with_rag():
    """The RAG variant of the test helper should match plain enrichment output."""
    enricher = TestEnrichments()
    outcome = enricher.enrich_documents_with_rag(["text1", "text2"], ["doc1", "doc2"])
    expected = [
        "Additional text description of text1 as the input.",
        "Additional text description of text2 as the input.",
    ]
    assert outcome == expected
|
@@ -0,0 +1,117 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for utils/enrichments/ollama.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import pytest
|
6
|
+
import ollama
|
7
|
+
from ..utils.enrichments.ollama import EnrichmentWithOllama
|
8
|
+
|
9
|
+
@pytest.fixture(name="ollama_config")
def fixture_ollama_config():
    """Return a dictionary with Ollama configuration.

    The keys match the keyword arguments the tests pass to EnrichmentWithOllama.
    The prompt contains an ``{input}`` field, presumably filled in by that class.
    """
    return {
        "model_name": "smollm2:360m",
        # Typos in the prompt ("addditional", "Ouput", "input terms") fixed so
        # the model sees clean instructions.
        "prompt_enrichment": """
        Given the input as a list of strings, please return the list of additional information of
        each input term using your prior knowledge.

        Example:
        Input: ['acetaminophen', 'aspirin']
        Output: ['acetaminophen is a medication used to treat pain and fever',
        'aspirin is a medication used to treat pain, fever, and inflammation']

        Do not include any pretext as the output, only the list of strings enriched.

        Input: {input}
        """,
        "temperature": 0.0,
        "streaming": False,
    }
|
30
|
+
|
31
|
+
def test_no_model_ollama(ollama_config):
    """Test the case when the Ollama model is not available."""
    small_model = "smollm2:135m"  # Choose a small model

    # Ensure the model is absent first; a ResponseError here is ignored.
    try:
        ollama.delete(small_model)
    except ollama.ResponseError:
        pass

    # Per the expected message, construction pulls the missing model and then
    # raises ValueError.
    expected_error = f"Error: Pulled {small_model} model and restarted Ollama server."
    with pytest.raises(ValueError, match=expected_error):
        EnrichmentWithOllama(
            model_name=small_model,
            prompt_enrichment=ollama_config["prompt_enrichment"],
            temperature=ollama_config["temperature"],
            streaming=ollama_config["streaming"],
        )

    # Remove the model that was just pulled.
    ollama.delete(small_model)
|
53
|
+
|
54
|
+
def test_enrich_nodes_ollama(ollama_config):
    """Test the Ollama textual enrichment class for node enrichment."""
    settings = ollama_config
    enricher = EnrichmentWithOllama(
        model_name=settings["model_name"],
        prompt_enrichment=settings["prompt_enrichment"],
        temperature=settings["temperature"],
        streaming=settings["streaming"],
    )

    drug_nodes = ["Adalimumab", "Infliximab"]
    enriched = enricher.enrich_documents(drug_nodes)

    # One output per node, each different from its input.
    assert len(enriched) == 2
    assert all(after != before for after, before in zip(enriched, drug_nodes))
|
73
|
+
|
74
|
+
|
75
|
+
def test_enrich_relations_ollama(ollama_config):
    """Test the Ollama textual enrichment class for relation enrichment."""
    settings = ollama_config
    enricher = EnrichmentWithOllama(
        model_name=settings["model_name"],
        prompt_enrichment=settings["prompt_enrichment"],
        temperature=settings["temperature"],
        streaming=settings["streaming"],
    )

    edge_texts = [
        "IL23R-gene causation disease-inflammatory bowel diseases",
        "NOD2-gene causation disease-inflammatory bowel diseases",
    ]
    enriched = enricher.enrich_documents(edge_texts)

    # One output per relation, each different from its input.
    assert len(enriched) == 2
    assert all(after != before for after, before in zip(enriched, edge_texts))
|
97
|
+
|
98
|
+
|
99
|
+
def test_enrich_ollama_rag(ollama_config):
    """Test the Ollama textual enrichment class for enrichment with RAG (not implemented)."""
    settings = ollama_config
    enricher = EnrichmentWithOllama(
        model_name=settings["model_name"],
        prompt_enrichment=settings["prompt_enrichment"],
        temperature=settings["temperature"],
        streaming=settings["streaming"],
    )

    drug_nodes = ["Adalimumab", "Infliximab"]
    reference_docs = [r"\path\to\doc1", r"\path\to\doc2"]
    enriched = enricher.enrich_documents_with_rag(drug_nodes, reference_docs)

    # One output per node, each different from its input.
    assert len(enriched) == 2
    assert all(after != before for after, before in zip(enriched, drug_nodes))
|
@@ -0,0 +1,36 @@
|
|
1
|
+
"""
|
2
|
+
Enrichments interface
|
3
|
+
"""
|
4
|
+
|
5
|
+
from abc import ABC, abstractmethod
|
6
|
+
|
7
|
+
class Enrichments(ABC):
    """Interface for enrichment models.

    This is an interface meant for implementing text enrichment models.

    Enrichment models are used to enrich node or relation features in a given knowledge graph.
    """

    @abstractmethod
    def enrich_documents(self, texts: list[str]) -> list[str]:
        """Enrich documents.

        Args:
            texts: List of documents to enrich.

        Returns:
            List of enriched documents, one string per input text.
        """

    @abstractmethod
    def enrich_documents_with_rag(self, texts: list[str], docs: list[str]) -> list[str]:
        """Enrich documents with RAG.

        Args:
            texts: List of documents to enrich.
            docs: List of reference documents to enrich the input texts.

        Returns:
            List of enriched documents with RAG.
        """
|