aiagents4pharma 1.8.0__py3-none-any.whl → 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/__init__.py +9 -5
- aiagents4pharma/configs/__init__.py +5 -0
- aiagents4pharma/configs/config.yaml +4 -0
- aiagents4pharma/configs/talk2biomodels/__init__.py +6 -0
- aiagents4pharma/configs/talk2biomodels/agents/__init__.py +5 -0
- aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/__init__.py +3 -0
- aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml +14 -0
- aiagents4pharma/configs/talk2biomodels/tools/__init__.py +4 -0
- aiagents4pharma/configs/talk2biomodels/tools/ask_question/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/agents/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py +96 -0
- aiagents4pharma/talk2biomodels/api/__init__.py +6 -0
- aiagents4pharma/talk2biomodels/api/kegg.py +83 -0
- aiagents4pharma/talk2biomodels/api/ols.py +72 -0
- aiagents4pharma/talk2biomodels/api/uniprot.py +35 -0
- aiagents4pharma/talk2biomodels/models/basico_model.py +29 -32
- aiagents4pharma/talk2biomodels/models/sys_bio_model.py +9 -6
- aiagents4pharma/talk2biomodels/states/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +41 -0
- aiagents4pharma/talk2biomodels/tests/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/tests/test_api.py +57 -0
- aiagents4pharma/talk2biomodels/tests/test_ask_question.py +44 -0
- aiagents4pharma/talk2biomodels/tests/test_basico_model.py +54 -0
- aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +171 -0
- aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +26 -0
- aiagents4pharma/talk2biomodels/tests/test_integration.py +126 -0
- aiagents4pharma/talk2biomodels/tests/test_param_scan.py +68 -0
- aiagents4pharma/talk2biomodels/tests/test_query_article.py +76 -0
- aiagents4pharma/talk2biomodels/tests/test_search_models.py +28 -0
- aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +39 -0
- aiagents4pharma/talk2biomodels/tests/test_steady_state.py +90 -0
- aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +63 -0
- aiagents4pharma/talk2biomodels/tools/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/tools/ask_question.py +61 -18
- aiagents4pharma/talk2biomodels/tools/custom_plotter.py +20 -14
- aiagents4pharma/talk2biomodels/tools/get_annotation.py +304 -0
- aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +11 -9
- aiagents4pharma/talk2biomodels/tools/load_arguments.py +114 -0
- aiagents4pharma/talk2biomodels/tools/load_biomodel.py +0 -1
- aiagents4pharma/talk2biomodels/tools/parameter_scan.py +287 -0
- aiagents4pharma/talk2biomodels/tools/query_article.py +59 -0
- aiagents4pharma/talk2biomodels/tools/simulate_model.py +35 -90
- aiagents4pharma/talk2biomodels/tools/steady_state.py +167 -0
- aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +23 -0
- aiagents4pharma/talk2cells/tools/scp_agent/__init__.py +6 -0
- aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py +25 -0
- aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py +79 -0
- aiagents4pharma/talk2competitors/__init__.py +5 -0
- aiagents4pharma/talk2competitors/agents/__init__.py +6 -0
- aiagents4pharma/talk2competitors/agents/main_agent.py +130 -0
- aiagents4pharma/talk2competitors/agents/s2_agent.py +75 -0
- aiagents4pharma/talk2competitors/config/__init__.py +5 -0
- aiagents4pharma/talk2competitors/config/config.py +110 -0
- aiagents4pharma/talk2competitors/state/__init__.py +5 -0
- aiagents4pharma/talk2competitors/state/state_talk2competitors.py +32 -0
- aiagents4pharma/talk2competitors/tests/__init__.py +3 -0
- aiagents4pharma/talk2competitors/tests/test_langgraph.py +274 -0
- aiagents4pharma/talk2competitors/tools/__init__.py +7 -0
- aiagents4pharma/talk2competitors/tools/s2/__init__.py +8 -0
- aiagents4pharma/talk2competitors/tools/s2/display_results.py +25 -0
- aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py +132 -0
- aiagents4pharma/talk2competitors/tools/s2/search.py +119 -0
- aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py +141 -0
- aiagents4pharma/talk2knowledgegraphs/__init__.py +2 -1
- aiagents4pharma/talk2knowledgegraphs/tests/__init__.py +0 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +242 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +29 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +73 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +116 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +47 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +45 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py +40 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py +39 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +117 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py +36 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py +123 -0
- {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/METADATA +44 -25
- aiagents4pharma-1.15.0.dist-info/RECORD +102 -0
- aiagents4pharma-1.8.0.dist-info/RECORD +0 -35
- {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,141 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
"""
|
4
|
+
This tool is used to return recommendations for a single paper.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from typing import Annotated, Any, Dict, Optional
|
9
|
+
|
10
|
+
import pandas as pd
|
11
|
+
import requests
|
12
|
+
from langchain_core.messages import ToolMessage
|
13
|
+
from langchain_core.tools import tool
|
14
|
+
from langchain_core.tools.base import InjectedToolCallId
|
15
|
+
from langgraph.types import Command
|
16
|
+
from pydantic import BaseModel, Field
|
17
|
+
|
18
|
+
# Configure logging
# NOTE(review): calling basicConfig at import time configures the root logger
# for the entire application; library modules usually leave this to the
# program entry point — confirm this is intentional.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
+
|
22
|
+
|
23
|
+
class SinglePaperRecInput(BaseModel):
    """Input schema for single paper recommendation tool."""

    # Semantic Scholar paper identifier the recommendations are based on.
    paper_id: str = Field(
        description="Semantic Scholar Paper ID to get recommendations for (40-character string)"
    )
    # Cap on the number of returned recommendations; the API allows at most 500.
    limit: int = Field(
        default=2,
        description="Maximum number of recommendations to return",
        ge=1,
        le=500,
    )
    # Optional publication-year filter, forwarded verbatim to the API.
    year: Optional[str] = Field(
        default=None,
        description="Year range in format: YYYY for specific year, "
        "YYYY- for papers after year, -YYYY for papers before year, or YYYY:YYYY for range",
    )
    # Injected by the framework at call time; not supplied by the model.
    tool_call_id: Annotated[str, InjectedToolCallId]
    model_config = {"arbitrary_types_allowed": True}
+
@tool(args_schema=SinglePaperRecInput)
def get_single_paper_recommendations(
    paper_id: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    limit: int = 2,
    year: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Get paper recommendations based on a single paper.

    Args:
        paper_id (str): The Semantic Scholar Paper ID to get recommendations for.
        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
        limit (int, optional): The maximum number of recommendations to return. Defaults to 2.
        year (str, optional): Year range for papers.
            Supports formats like "2024-", "-2024", "2024:2025". Defaults to None.

    Returns:
        Dict[str, Any]: A Command updating state with the recommendations and a
        ToolMessage carrying a markdown table of the results.
    """
    logger.info("Starting single paper recommendations search.")

    endpoint = (
        f"https://api.semanticscholar.org/recommendations/v1/papers/forpaper/{paper_id}"
    )
    params = {
        "limit": min(limit, 500),  # Max 500 per API docs
        "fields": "paperId,title,abstract,year,authors,citationCount,url",
        "from": "all-cs",  # Using all-cs pool as specified in docs
    }

    # Add year parameter if provided
    if year:
        params["year"] = year

    # Issue the request exactly once. (A second, identical requests.get call
    # was previously made here, doubling latency and API quota usage; its
    # first response — and the dead `papers = data.get("data", [])` derived
    # from it — was discarded.)
    response = requests.get(endpoint, params=params, timeout=10)
    logger.info(
        "API Response Status for recommendations of paper %s: %s",
        paper_id,
        response.status_code,
    )
    logger.info("Request params: %s", params)

    data = response.json()
    recommendations = data.get("recommendedPapers", [])

    # Keep only entries that have both a title and an author list,
    # keyed by their Semantic Scholar paper ID.
    filtered_papers = {
        paper["paperId"]: {
            "Title": paper.get("title", "N/A"),
            "Abstract": paper.get("abstract", "N/A"),
            "Year": paper.get("year", "N/A"),
            "Citation Count": paper.get("citationCount", "N/A"),
            "URL": paper.get("url", "N/A"),
        }
        for paper in recommendations
        if paper.get("title") and paper.get("authors")
    }

    # Create a DataFrame for pretty printing
    df = pd.DataFrame(filtered_papers)

    # Human-readable summaries, used only for logging below.
    papers = [
        f"Paper ID: {rec_id}\n"
        f"Title: {rec_data['Title']}\n"
        f"Abstract: {rec_data['Abstract']}\n"
        f"Year: {rec_data['Year']}\n"
        f"Citations: {rec_data['Citation Count']}\n"
        f"URL: {rec_data['URL']}\n"
        for rec_id, rec_data in filtered_papers.items()
    ]

    # Convert DataFrame to markdown table
    markdown_table = df.to_markdown(tablefmt="grid")
    logger.info("Search results: %s", papers)

    return Command(
        update={
            "papers": filtered_papers,  # Dictionary keyed by paper ID
            "messages": [
                ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
            ],
        }
    )
File without changes
|
@@ -0,0 +1,242 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for datasets/primekg_loader.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import shutil
|
7
|
+
import pytest
|
8
|
+
from ..datasets.biobridge_primekg import BioBridgePrimeKG
|
9
|
+
|
10
|
+
# Remove the data folder for testing if it exists
# Shared PrimeKG download location, reused across BioBridge tests.
PRIMEKG_LOCAL_DIR = "../data/primekg_test/"
# BioBridge-specific artifacts (data config, embeddings, processed triplets).
LOCAL_DIR = "../data/biobridge_primekg_test/"
# Start each test session from a clean slate; ignore_errors covers first runs.
# NOTE(review): only LOCAL_DIR is wiped — PRIMEKG_LOCAL_DIR is presumably kept
# to avoid re-downloading PrimeKG; confirm this is intentional.
shutil.rmtree(LOCAL_DIR, ignore_errors=True)
|
+
@pytest.fixture(name="biobridge_primekg")
def biobridge_primekg_fixture():
    """Provide a BioBridgePrimeKG instance wired to the test directories."""
    loader = BioBridgePrimeKG(
        primekg_dir=PRIMEKG_LOCAL_DIR,
        local_dir=LOCAL_DIR,
    )
    return loader
23
|
+
def test_download_primekg(biobridge_primekg):
    """
    Test the loading method of the BioBridge-PrimeKG class by downloading data from repository.
    """
    # Download/process everything once, then pull each artifact out.
    biobridge_primekg.load_data()
    primekg = biobridge_primekg.get_primekg()
    primekg_nodes = primekg.get_nodes()
    primekg_edges = primekg.get_edges()
    data_config = biobridge_primekg.get_data_config()
    emb_dict = biobridge_primekg.get_node_embeddings()
    triplets = biobridge_primekg.get_primekg_triplets()
    splits = biobridge_primekg.get_train_test_split()
    node_info = biobridge_primekg.get_node_info_dict()

    # Both working directories must exist after loading.
    assert os.path.exists(biobridge_primekg.primekg_dir)
    assert os.path.exists(biobridge_primekg.local_dir)
    # Raw PrimeKG downloads and their processed counterparts.
    for fname in ("nodes.tab", "primekg_nodes.tsv.gz",
                  "edges.csv", "primekg_edges.tsv.gz"):
        assert os.path.exists(f"{biobridge_primekg.primekg_dir}/{fname}")
    # BioBridge data config
    assert os.path.exists(f"{biobridge_primekg.local_dir}/data_config.json")
    # BioBridge embedding pickles
    for fname in ("protein.pkl", "mf.pkl", "cc.pkl", "bp.pkl",
                  "drug.pkl", "disease.pkl", "embedding_dict.pkl"):
        assert os.path.exists(f"{biobridge_primekg.local_dir}/embeddings/{fname}")
    # BioBridge processed files
    for fname in ("protein.csv", "mf.csv", "cc.csv", "bp.csv",
                  "drug.csv", "disease.csv",
                  "triplet_full.tsv.gz", "triplet_full_altered.tsv.gz",
                  "node_train.tsv.gz", "triplet_train.tsv.gz",
                  "node_test.tsv.gz", "triplet_test.tsv.gz"):
        assert os.path.exists(f"{biobridge_primekg.local_dir}/processed/{fname}")
    # Processed PrimeKG dataframes
    assert primekg_nodes is not None
    assert len(primekg_nodes) > 0
    assert primekg_nodes.shape[0] == 129375
    assert primekg_edges is not None
    assert len(primekg_edges) > 0
    assert primekg_edges.shape[0] == 8100498
    # BioBridge data config contents
    assert data_config is not None
    assert len(data_config) > 0
    assert len(data_config['node_type']) == 10
    assert len(data_config['relation_type']) == 18
    assert len(data_config['emb_dim']) == 6
    # BioBridge embeddings
    assert emb_dict is not None
    assert len(emb_dict) > 0
    assert len(emb_dict) == 85466
    # BioBridge triplets
    assert triplets is not None
    assert len(triplets) > 0
    assert triplets.shape[0] == 3904610
    # Train/test split sizes
    expected_split_sizes = {'train': 3510930, 'node_train': 76486,
                            'test': 393680, 'node_test': 8495}
    assert list(splits.keys()) == list(expected_split_sizes.keys())
    for split_name, size in expected_split_sizes.items():
        assert len(splits[split_name]) == size
    # Node info dictionary
    expected_node_counts = {'gene/protein': 19162,
                            'molecular_function': 10966,
                            'cellular_component': 4013,
                            'biological_process': 27478,
                            'drug': 6948,
                            'disease': 44133}
    assert list(node_info.keys()) == list(expected_node_counts.keys())
    for node_type, count in expected_node_counts.items():
        assert len(node_info[node_type]) == count
|
123
|
+
def test_load_existing_primekg(biobridge_primekg):
    """
    Test the loading method of the BioBridge-PrimeKG class by loading existing data in local.
    """
    # Second load in the session: everything should come from the local cache.
    biobridge_primekg.load_data()
    primekg = biobridge_primekg.get_primekg()
    primekg_nodes = primekg.get_nodes()
    primekg_edges = primekg.get_edges()
    data_config = biobridge_primekg.get_data_config()
    emb_dict = biobridge_primekg.get_node_embeddings()
    triplets = biobridge_primekg.get_primekg_triplets()
    splits = biobridge_primekg.get_train_test_split()
    node_info = biobridge_primekg.get_node_info_dict()

    # Both working directories must still exist.
    assert os.path.exists(biobridge_primekg.primekg_dir)
    assert os.path.exists(biobridge_primekg.local_dir)
    # Raw PrimeKG downloads and their processed counterparts.
    for fname in ("nodes.tab", "primekg_nodes.tsv.gz",
                  "edges.csv", "primekg_edges.tsv.gz"):
        assert os.path.exists(f"{biobridge_primekg.primekg_dir}/{fname}")
    # BioBridge data config
    assert os.path.exists(f"{biobridge_primekg.local_dir}/data_config.json")
    # BioBridge embedding pickles
    for fname in ("protein.pkl", "mf.pkl", "cc.pkl", "bp.pkl",
                  "drug.pkl", "disease.pkl", "embedding_dict.pkl"):
        assert os.path.exists(f"{biobridge_primekg.local_dir}/embeddings/{fname}")
    # BioBridge processed files
    for fname in ("protein.csv", "mf.csv", "cc.csv", "bp.csv",
                  "drug.csv", "disease.csv",
                  "triplet_full.tsv.gz", "triplet_full_altered.tsv.gz",
                  "node_train.tsv.gz", "triplet_train.tsv.gz",
                  "node_test.tsv.gz", "triplet_test.tsv.gz"):
        assert os.path.exists(f"{biobridge_primekg.local_dir}/processed/{fname}")
    # Processed PrimeKG dataframes
    assert primekg_nodes is not None
    assert len(primekg_nodes) > 0
    assert primekg_nodes.shape[0] == 129375
    assert primekg_edges is not None
    assert len(primekg_edges) > 0
    assert primekg_edges.shape[0] == 8100498
    # BioBridge data config contents
    assert data_config is not None
    assert len(data_config) > 0
    assert len(data_config['node_type']) == 10
    assert len(data_config['relation_type']) == 18
    assert len(data_config['emb_dim']) == 6
    # BioBridge embeddings
    assert emb_dict is not None
    assert len(emb_dict) > 0
    assert len(emb_dict) == 85466
    # BioBridge triplets
    assert triplets is not None
    assert len(triplets) > 0
    assert triplets.shape[0] == 3904610
    # Train/test split sizes
    expected_split_sizes = {'train': 3510930, 'node_train': 76486,
                            'test': 393680, 'node_test': 8495}
    assert list(splits.keys()) == list(expected_split_sizes.keys())
    for split_name, size in expected_split_sizes.items():
        assert len(splits[split_name]) == size
    # Node info dictionary
    expected_node_counts = {'gene/protein': 19162,
                            'molecular_function': 10966,
                            'cellular_component': 4013,
                            'biological_process': 27478,
                            'drug': 6948,
                            'disease': 44133}
    assert list(node_info.keys()) == list(expected_node_counts.keys())
    for node_type, count in expected_node_counts.items():
        assert len(node_info[node_type]) == count
222
|
+
# def test_load_existing_primekg_with_negative_triplets(biobridge_primekg):
|
223
|
+
# """
|
224
|
+
# Test the loading method of the BioBridge-PrimeKG class by loading existing data in local.
|
225
|
+
# In addition, it builds negative triplets for training data.
|
226
|
+
# """
|
227
|
+
# # Load BioBridge-PrimeKG data
|
228
|
+
# # Using 1 negative sample per positive triplet
|
229
|
+
# biobridge_primekg.load_data(build_neg_triplest=True, n_neg_samples=1)
|
230
|
+
# biobridge_neg_triplets = biobridge_primekg.get_primekg_triplets_negative()
|
231
|
+
|
232
|
+
# # Check if the local directories exists
|
233
|
+
# assert os.path.exists(biobridge_primekg.primekg_dir)
|
234
|
+
# assert os.path.exists(biobridge_primekg.local_dir)
|
235
|
+
# # Check if downloaded and processed files exist
|
236
|
+
# path = f"{biobridge_primekg.local_dir}/processed/triplet_train_negative.tsv.gz"
|
237
|
+
# assert os.path.exists(path)
|
238
|
+
# # Check processed BioBridge triplets
|
239
|
+
# assert biobridge_neg_triplets is not None
|
240
|
+
# assert len(biobridge_neg_triplets) > 0
|
241
|
+
# assert biobridge_neg_triplets.shape[0] == 3510930
|
242
|
+
# assert len(biobridge_neg_triplets.negative_tail_index[0]) == 1
|
@@ -0,0 +1,29 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for datasets/dataset.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
from ..datasets.dataset import Dataset
|
6
|
+
|
7
|
+
class MockDataset(Dataset):
    """
    Mock dataset class for testing purposes.

    Minimal concrete subclass of the Dataset interface; both hooks are
    deliberate no-ops so the abstract base class can be instantiated and
    its contract exercised without touching any real data.
    """
    def setup(self):
        # No-op: nothing to configure for the mock.
        pass

    def load_data(self):
        # No-op: the mock loads no data.
        pass
|
17
|
+
def test_dataset_setup():
    """
    Test the setup method of the Dataset class.
    """
    mock = MockDataset()
    # setup() is a no-op and should return nothing.
    result = mock.setup()
    assert result is None
|
24
|
+
def test_dataset_load_data():
    """
    Test the load_data method of the Dataset class.
    """
    mock = MockDataset()
    # load_data() is a no-op and should return nothing.
    result = mock.load_data()
    assert result is None
@@ -0,0 +1,73 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for datasets/primekg_loader.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import shutil
|
7
|
+
import pytest
|
8
|
+
from ..datasets.primekg import PrimeKG
|
9
|
+
|
10
|
+
# Remove the data folder for testing if it exists
# Download/cache location used by the PrimeKG tests.
LOCAL_DIR = "../data/primekg_test/"
# Start each test session from a clean slate; ignore_errors covers first runs.
shutil.rmtree(LOCAL_DIR, ignore_errors=True)
+
|
14
|
+
@pytest.fixture(name="primekg")
def primekg_fixture():
    """Provide a PrimeKG instance wired to the test directory."""
    loader = PrimeKG(local_dir=LOCAL_DIR)
    return loader
+
|
21
|
+
def test_download_primekg(primekg):
    """
    Test the loading method of the PrimeKG class by downloading PrimeKG from server.
    """
    # Download and process the data, then pull out both dataframes.
    primekg.load_data()
    nodes = primekg.get_nodes()
    edges = primekg.get_edges()

    # The local cache directory must exist after loading.
    assert os.path.exists(primekg.local_dir)
    # Raw downloads plus their processed tsv.gz counterparts must be on disk.
    for fname in ("nodes.tab", f"{primekg.name}_nodes.tsv.gz",
                  "edges.csv", f"{primekg.name}_edges.tsv.gz"):
        assert os.path.exists(f"{primekg.local_dir}/{fname}")
    # Nodes
    assert nodes is not None
    assert len(nodes) > 0
    assert nodes.shape[0] == 129375
    # Edges
    assert edges is not None
    assert len(edges) > 0
    assert edges.shape[0] == 8100498
|
+
|
48
|
+
def test_load_existing_primekg(primekg):
    """
    Test the loading method of the PrimeKG class by loading existing PrimeKG in local.
    """
    # Second load in the session: data should come from the local cache.
    primekg.load_data()
    nodes = primekg.get_nodes()
    edges = primekg.get_edges()

    # The local cache directory must still exist.
    assert os.path.exists(primekg.local_dir)
    # Raw downloads plus their processed tsv.gz counterparts must be on disk.
    for fname in ("nodes.tab", f"{primekg.name}_nodes.tsv.gz",
                  "edges.csv", f"{primekg.name}_edges.tsv.gz"):
        assert os.path.exists(f"{primekg.local_dir}/{fname}")
    # Nodes
    assert nodes is not None
    assert len(nodes) > 0
    assert nodes.shape[0] == 129375
    # Edges
    assert edges is not None
    assert len(edges) > 0
    assert edges.shape[0] == 8100498
@@ -0,0 +1,116 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for datasets/starkqa_primekg_loader.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import shutil
|
7
|
+
import pytest
|
8
|
+
from ..datasets.starkqa_primekg import StarkQAPrimeKG
|
9
|
+
|
10
|
+
# Remove the data folder for testing if it exists
# Download/cache location used by the StarkQA-PrimeKG tests.
LOCAL_DIR = "../data/starkqa_primekg_test/"
# Start each test session from a clean slate; ignore_errors covers first runs.
shutil.rmtree(LOCAL_DIR, ignore_errors=True)
|
+
|
14
|
+
@pytest.fixture(name="starkqa_primekg")
def starkqa_primekg_fixture():
    """
    Fixture for creating an instance of StarkQAPrimeKGData.
    """
    loader = StarkQAPrimeKG(local_dir=LOCAL_DIR)
    return loader
|
21
|
+
def test_download_starkqa_primekg(starkqa_primekg):
    """
    Test the loading method of the StarkQAPrimeKGLoaderTool class by downloading files
    from HuggingFace Hub.
    """
    # Download everything once, then pull each artifact out.
    starkqa_primekg.load_data()
    starkqa_df = starkqa_primekg.get_starkqa()
    node_info = starkqa_primekg.get_starkqa_node_info()
    split_idx = starkqa_primekg.get_starkqa_split_indicies()
    query_embeddings = starkqa_primekg.get_query_embeddings()
    node_embeddings = starkqa_primekg.get_node_embeddings()

    # The local cache directory must exist after loading.
    assert os.path.exists(starkqa_primekg.local_dir)
    # Every expected download must be present on disk.
    for rel_path in ('qa/prime/split/test-0.1.index',
                     'qa/prime/split/test.index',
                     'qa/prime/split/train.index',
                     'qa/prime/split/val.index',
                     'qa/prime/stark_qa/stark_qa.csv',
                     'qa/prime/stark_qa/stark_qa_human_generated_eval.csv',
                     'skb/prime/processed.zip'):
        assert os.path.exists(f"{starkqa_primekg.local_dir}/{rel_path}")
    # QA dataframe
    assert starkqa_df is not None
    assert len(starkqa_df) > 0
    assert starkqa_df.shape[0] == 11204
    # Node information
    assert node_info is not None
    assert len(node_info) == 129375
    # Split indices
    expected_split_sizes = {'train': 6162, 'val': 2241,
                            'test': 2801, 'test-0.1': 280}
    assert list(split_idx.keys()) == list(expected_split_sizes.keys())
    for split_name, size in expected_split_sizes.items():
        assert len(split_idx[split_name]) == size
    # Query embeddings
    assert query_embeddings is not None
    assert len(query_embeddings) == 11204
    assert query_embeddings[0].shape[1] == 1536
    # Node embeddings
    assert node_embeddings is not None
    assert len(node_embeddings) == 129375
    assert node_embeddings[0].shape[1] == 1536
|
+
def test_load_existing_starkqa_primekg(starkqa_primekg):
    """
    Test the loading method of the StarkQAPrimeKGLoaderTool class by loading existing files
    in the local directory.
    """
    # Second load in the session: files should come from the local cache.
    starkqa_primekg.load_data()
    starkqa_df = starkqa_primekg.get_starkqa()
    node_info = starkqa_primekg.get_starkqa_node_info()
    split_idx = starkqa_primekg.get_starkqa_split_indicies()
    query_embeddings = starkqa_primekg.get_query_embeddings()
    node_embeddings = starkqa_primekg.get_node_embeddings()

    # The local cache directory must still exist.
    assert os.path.exists(starkqa_primekg.local_dir)
    # Every expected file must be present on disk.
    for rel_path in ('qa/prime/split/test-0.1.index',
                     'qa/prime/split/test.index',
                     'qa/prime/split/train.index',
                     'qa/prime/split/val.index',
                     'qa/prime/stark_qa/stark_qa.csv',
                     'qa/prime/stark_qa/stark_qa_human_generated_eval.csv',
                     'skb/prime/processed.zip'):
        assert os.path.exists(f"{starkqa_primekg.local_dir}/{rel_path}")
    # QA dataframe
    assert starkqa_df is not None
    assert len(starkqa_df) > 0
    assert starkqa_df.shape[0] == 11204
    # Node information
    assert node_info is not None
    assert len(node_info) == 129375
    # Split indices
    expected_split_sizes = {'train': 6162, 'val': 2241,
                            'test': 2801, 'test-0.1': 280}
    assert list(split_idx.keys()) == list(expected_split_sizes.keys())
    for split_name, size in expected_split_sizes.items():
        assert len(split_idx[split_name]) == size
    # Query embeddings
    assert query_embeddings is not None
    assert len(query_embeddings) == 11204
    assert query_embeddings[0].shape[1] == 1536
    # Node embeddings
    assert node_embeddings is not None
    assert len(node_embeddings) == 129375
    assert node_embeddings[0].shape[1] == 1536
|
@@ -0,0 +1,47 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for utils/embeddings/embeddings.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import pytest
|
6
|
+
from ..utils.embeddings.embeddings import Embeddings
|
7
|
+
|
8
|
+
class TestEmbeddings(Embeddings):
    """Test implementation of the Embeddings interface for testing purposes.

    Returns a fixed 3-dimensional vector for every input so the interface's
    sync (and inherited async) methods can be asserted deterministically.
    """

    # The class name starts with "Test", so pytest would otherwise try to
    # collect it as a test class; opt it out of collection explicitly.
    __test__ = False

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Return one constant embedding per input text."""
        return [[0.1, 0.2, 0.3] for _ in texts]

    def embed_query(self, text: str) -> list[float]:
        """Return the constant embedding regardless of the query text."""
        return [0.1, 0.2, 0.3]
|
17
|
+
def test_embed_documents():
    """Test embedding documents using the Embeddings interface."""
    emb = TestEmbeddings()
    docs = ["text1", "text2"]
    vectors = emb.embed_documents(docs)
    # One constant 3-dim vector per input document.
    assert vectors == [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]
|
+
|
25
|
+
def test_embed_query():
    """Test embedding a query using the Embeddings interface."""
    emb = TestEmbeddings()
    vector = emb.embed_query("query")
    # The constant 3-dim vector, independent of the query text.
    assert vector == [0.1, 0.2, 0.3]
|
32
|
+
@pytest.mark.asyncio
async def test_aembed_documents():
    """Test asynchronous embedding of documents using the Embeddings interface."""
    emb = TestEmbeddings()
    docs = ["text1", "text2"]
    vectors = await emb.aembed_documents(docs)
    # The async path must match the synchronous embed_documents output.
    assert vectors == [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]
+
|
41
|
+
@pytest.mark.asyncio
async def test_aembed_query():
    """Test asynchronous embedding of a query using the Embeddings interface."""
    emb = TestEmbeddings()
    vector = await emb.aembed_query("query")
    # The async path must match the synchronous embed_query output.
    assert vector == [0.1, 0.2, 0.3]
|