aiagents4pharma 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/__init__.py +1 -0
- aiagents4pharma/configs/__init__.py +5 -0
- aiagents4pharma/configs/config.yaml +3 -0
- aiagents4pharma/configs/talk2biomodels/__init__.py +5 -0
- aiagents4pharma/configs/talk2biomodels/agents/__init__.py +5 -0
- aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/__init__.py +3 -0
- aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml +6 -0
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py +3 -3
- aiagents4pharma/talk2biomodels/tests/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/tests/test_basico_model.py +55 -0
- aiagents4pharma/talk2biomodels/tests/test_langgraph.py +189 -0
- aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +57 -0
- aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +23 -0
- aiagents4pharma/talk2competitors/__init__.py +0 -0
- aiagents4pharma/talk2knowledgegraphs/tests/__init__.py +0 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +242 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +29 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +73 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +116 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +47 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +45 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py +40 -0
- {aiagents4pharma-1.8.2.dist-info → aiagents4pharma-1.8.3.dist-info}/METADATA +1 -1
- {aiagents4pharma-1.8.2.dist-info → aiagents4pharma-1.8.3.dist-info}/RECORD +27 -7
- {aiagents4pharma-1.8.2.dist-info → aiagents4pharma-1.8.3.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.8.2.dist-info → aiagents4pharma-1.8.3.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.8.2.dist-info → aiagents4pharma-1.8.3.dist-info}/top_level.txt +0 -0
aiagents4pharma/__init__.py
CHANGED
@@ -52,10 +52,10 @@ def get_app(uniq_id, llm_model='gpt-4o-mini'):
|
|
52
52
|
llm = ChatOpenAI(model=llm_model, temperature=0)
|
53
53
|
# Load hydra configuration
|
54
54
|
logger.log(logging.INFO, "Load Hydra configuration for Talk2BioModels agent.")
|
55
|
-
with hydra.initialize(version_base=None, config_path="
|
55
|
+
with hydra.initialize(version_base=None, config_path="../../configs"):
|
56
56
|
cfg = hydra.compose(config_name='config',
|
57
|
-
overrides=['
|
58
|
-
cfg = cfg.
|
57
|
+
overrides=['talk2biomodels/agents/t2b_agent=default'])
|
58
|
+
cfg = cfg.talk2biomodels.agents.t2b_agent
|
59
59
|
logger.log(logging.INFO, "state_modifier: %s", cfg.state_modifier)
|
60
60
|
# Create the agent
|
61
61
|
model = create_react_agent(
|
@@ -0,0 +1,55 @@
|
|
1
|
+
'''
|
2
|
+
Unit tests for the BasicoModel class.
|
3
|
+
'''
|
4
|
+
|
5
|
+
import pandas as pd
|
6
|
+
import pytest
|
7
|
+
import basico
|
8
|
+
from ..models.basico_model import BasicoModel
|
9
|
+
|
10
|
+
@pytest.fixture(name="model")
def model_fixture():
    """
    Provide a BasicoModel built from BioModels entry 64.

    The species override {"Pyruvate": 100} is applied, and a short run
    (duration 2, interval 2) is configured so the tests stay fast.
    """
    fixture_kwargs = {
        "biomodel_id": 64,
        "species": {"Pyruvate": 100},
        "duration": 2,
        "interval": 2,
    }
    return BasicoModel(**fixture_kwargs)
|
16
|
+
|
17
|
+
def test_with_biomodel_id(model):
    """
    Check a BasicoModel created from a biomodel_id.

    The model must keep its id, simulate into DataFrames (with real parameter
    overrides and with a placeholder {None: None} mapping), and expose the
    description reported by BioModels.
    """
    assert model.biomodel_id == 64
    run_settings = {"duration": 2, "interval": 2}
    # Simulation with explicit parameter overrides
    overridden = model.simulate(parameters={'Pyruvate': 0.5, 'KmPFKF6P': 1.5},
                                **run_settings)
    assert isinstance(overridden, pd.DataFrame)
    # Simulation with a placeholder mapping
    placeholder = model.simulate(parameters={None: None}, **run_settings)
    assert isinstance(placeholder, pd.DataFrame)
    # The description must match the BioModels metadata for this id
    assert model.description == \
        basico.biomodels.get_model_info(model.biomodel_id)["description"]
|
30
|
+
|
31
|
+
def test_with_sbml_file():
    """
    Check a BasicoModel created from a local SBML file.

    NOTE(review): the path is relative, so it is resolved against the
    directory pytest is launched from.
    """
    sbml_path = "./BIOMD0000000064_url.xml"
    model_object = BasicoModel(sbml_file_path=sbml_path)
    assert model_object.sbml_file_path == sbml_path
    # Default parameters
    default_run = model_object.simulate(duration=2, interval=2)
    assert isinstance(default_run, pd.DataFrame)
    # A single parameter override
    nadh_run = model_object.simulate(parameters={'NADH': 0.5}, duration=2, interval=2)
    assert isinstance(nadh_run, pd.DataFrame)
|
40
|
+
|
41
|
+
def test_check_biomodel_id_or_sbml_file_path():
    '''
    Check that BasicoModel raises ValueError when it is given neither a
    biomodel_id nor an sbml_file_path.
    '''
    kwargs_without_source = {"species": {"Pyruvate": 100}, "duration": 2, "interval": 2}
    with pytest.raises(ValueError):
        BasicoModel(**kwargs_without_source)
|
47
|
+
|
48
|
+
def test_get_model_metadata():
    """
    Check the model type and parameter count returned by
    BasicoModel.get_model_metadata.
    """
    basico_model = BasicoModel(biomodel_id=64)
    metadata = basico_model.get_model_metadata()
    assert metadata["Model Type"] == "SBML Model (COPASI)"
    # NOTE(review): basico.get_parameters() is called without a model, so it
    # presumably reads basico's current model (the one loaded above).
    expected_count = len(basico.get_parameters())
    assert metadata["Parameter Count"] == expected_count
|
@@ -0,0 +1,189 @@
|
|
1
|
+
'''
|
2
|
+
Test cases for the Talk2BioModels agent (agents/t2b_agent.py).
|
3
|
+
'''
|
4
|
+
|
5
|
+
from langchain_core.messages import HumanMessage, ToolMessage
|
6
|
+
from ..agents.t2b_agent import get_app
|
7
|
+
|
8
|
+
def test_get_modelinfo_tool():
    '''
    Check the get_modelinfo tool on a local SBML file.

    The file path is placed in the agent state, the agent is asked to
    extract model information, and the final reply must be text.
    '''
    thread_id = 12345
    app = get_app(thread_id)
    config = {"configurable": {"thread_id": thread_id}}
    # Point the agent at the SBML file before asking about it
    app.update_state(config, {"sbml_file_path": ["BIOMD0000000537.xml"]})
    question = "Extract all relevant information from the uploaded model."
    response = app.invoke({"messages": [HumanMessage(content=question)]},
                          config=config)
    last_message = response["messages"][-1]
    assert isinstance(last_message.content, str)
|
26
|
+
|
27
|
+
def test_search_models_tool():
    '''
    Check the search_models tool.

    The agent is asked to search for Crohn's disease models, and its reply
    must mention BioModels entry BIOMD0000000537.
    '''
    thread_id = 12345
    app = get_app(thread_id)
    config = {"configurable": {"thread_id": thread_id}}
    app.update_state(config, {"llm_model": "gpt-4o-mini"})
    query = "Search for models on Crohn's disease."
    response = app.invoke({"messages": [HumanMessage(content=query)]},
                          config=config)
    reply = response["messages"][-1].content
    assert isinstance(reply, str)
    # The search result should include this BioModels id
    assert "BIOMD0000000537" in reply
|
48
|
+
|
49
|
+
def test_ask_question_tool():
    '''
    Check the ask_question tool when simulation results are not available.
    '''
    thread_id = 12345
    app = get_app(thread_id, llm_model='gpt-4o-mini')
    config = {"configurable": {"thread_id": thread_id}}
    app.update_state(config, {"llm_model": "gpt-4o-mini"})
    question = ("Call the ask_question tool to answer the "
                "question: What is the concentration of CRP "
                "in serum at 1000 hours?")
    response = app.invoke({"messages": [HumanMessage(content=question)]},
                          config=config)
    # Without simulation data the agent should still reply with text
    answer = response["messages"][-1].content
    assert isinstance(answer, str)
|
75
|
+
|
76
|
+
def _invoke_and_get_plotter_artifact(app, config, prompt):
    '''
    Send `prompt` to the agent and return the custom_plotter artifact
    produced by that turn.

    Only messages appended to the thread by this invocation are searched.
    This prevents returning an artifact left over from an earlier
    custom_plotter call in the same thread. It assumes the thread's message
    list only grows by appending, as with LangGraph's `add_messages`
    reducer -- TODO confirm against the agent state definition.

    Args:
        app: Compiled agent returned by `get_app`.
        config: Run config identifying the thread.
        prompt: User message to send.

    Returns:
        The `artifact` of the last custom_plotter ToolMessage from this
        turn, or an empty list if this turn did not call custom_plotter.
        The empty list is deliberately distinct from `None`, which the test
        below expects when the requested species cannot be plotted.
    '''
    history_len = len(app.get_state(config).values.get("messages", []))
    app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
    new_messages = app.get_state(config).values["messages"][history_len:]
    for msg in reversed(new_messages):
        if isinstance(msg, ToolMessage) and msg.name == "custom_plotter":
            return msg.artifact
    return []


def test_simulate_model_tool():
    '''
    Test the simulate_model tool, then ask_question and custom_plotter on
    the resulting simulation data within the same thread.
    '''
    unique_id = 123
    app = get_app(unique_id)
    config = {"configurable": {"thread_id": unique_id}}
    app.update_state(config, {"llm_model": "gpt-4o-mini"})

    ##########################################
    # Test simulate_model tool
    ##########################################
    prompt = "Simulate the model 537 for 2016 hours and intervals"
    prompt += " 2016 with an initial concentration of `DoseQ2W` "
    prompt += "set to 300 and `Dose` set to 0. Reset the concentration"
    prompt += " of `NAD` to 100 every 500 hours."
    response = app.invoke(
        {"messages": [HumanMessage(content=prompt)]},
        config=config
    )
    assistant_msg = response["messages"][-1].content
    # The agent's reply to the simulation request must be text
    assert isinstance(assistant_msg, str)

    ##########################################
    # Test ask_question tool when simulation
    # results are available
    ##########################################
    app.update_state(config, {"llm_model": "gpt-4o-mini"})
    prompt = "What is the concentration of CRP in serum at 1000 hours? "
    response = app.invoke(
        {"messages": [HumanMessage(content=prompt)]},
        config=config
    )
    assistant_msg = response["messages"][-1].content
    # The answer is expected to mention the simulated value 1.7
    assert "1.7" in assistant_msg, assistant_msg

    ##########################################
    # Test custom_plotter tool when the
    # simulation results are available
    ##########################################
    app.update_state(config, {"llm_model": "gpt-4o-mini"})
    predicted_artifact = _invoke_and_get_plotter_artifact(
        app, config, "Plot only CRP related species.")
    expected_artifact = ['CRP[serum]', 'CRPExtracellular',
                         'CRP Suppression (%)', 'CRP (% of baseline)',
                         'CRP[liver]']
    # Every CRP-related species must be among the plotted ones
    assert set(expected_artifact).issubset(set(predicted_artifact)), predicted_artifact

    ##########################################
    # Test custom_plotter tool when the
    # simulation results are available but
    # the species is not available
    ##########################################
    app.update_state(config, {"llm_model": "gpt-4o-mini"})
    predicted_artifact = _invoke_and_get_plotter_artifact(
        app, config, "Plot the species `TP53`.")
    # custom_plotter must have been called this turn and returned None
    assert predicted_artifact is None, predicted_artifact
|
@@ -0,0 +1,57 @@
|
|
1
|
+
'''
|
2
|
+
This file contains the unit tests for the SysBioModel base class.
|
3
|
+
'''
|
4
|
+
|
5
|
+
from typing import List, Dict, Union, Optional
|
6
|
+
from pydantic import Field
|
7
|
+
import pytest
|
8
|
+
from ..models.sys_bio_model import SysBioModel
|
9
|
+
|
10
|
+
class TestBioModel(SysBioModel):
    '''
    A test BioModel class for unit testing.

    A concrete SysBioModel subclass with trivial metadata and simulation
    logic, used by the tests in this module.
    '''

    # The "Test" prefix makes pytest try to collect this helper as a test
    # class and emit a PytestCollectionWarning (pydantic models define
    # __init__). Opt out of collection explicitly.
    __test__ = False

    biomodel_id: Optional[int] = Field(None, description="BioModel ID of the model")
    sbml_file_path: Optional[str] = Field(None, description="Path to an SBML file")
    name: Optional[str] = Field(..., description="Name of the model")
    description: Optional[str] = Field("", description="Description of the model")

    def get_model_metadata(self) -> Optional[int]:
        '''
        Get the metadata of the model.

        For this test double, the "metadata" is simply the biomodel_id.
        '''
        return self.biomodel_id

    def simulate(self,
                 parameters: Dict[str, Union[float, int]],
                 duration: Union[int, float]) -> List[float]:
        '''
        Simulate the model.

        Args:
            parameters: May contain 'param1' (offset) and 'param2' (slope);
                missing entries default to 0.0.
            duration: Number of steps; truncated with int().

        Returns:
            [param1 + param2 * t for t in range(int(duration))].
        '''
        param1 = parameters.get('param1', 0.0)
        param2 = parameters.get('param2', 0.0)
        return [param1 + param2 * t for t in range(int(duration))]
|
35
|
+
|
36
|
+
def test_get_model_metadata():
    '''
    Check that get_model_metadata returns the configured biomodel_id.
    '''
    test_model = TestBioModel(biomodel_id=123,
                              name="Test Model",
                              description="A test model")
    assert test_model.get_model_metadata() == 123
|
43
|
+
|
44
|
+
def test_check_biomodel_id_or_sbml_file_path():
    '''
    Check that construction raises ValueError when neither biomodel_id
    nor sbml_file_path is supplied.
    '''
    kwargs_without_source = {"name": "Test Model", "description": "A test model"}
    with pytest.raises(ValueError):
        TestBioModel(**kwargs_without_source)
|
50
|
+
|
51
|
+
def test_simulate():
    '''
    Check the linear series produced by the test model's simulate method.
    '''
    test_model = TestBioModel(biomodel_id=123,
                              name="Test Model",
                              description="A test model")
    series = test_model.simulate(parameters={'param1': 1.0, 'param2': 2.0},
                                 duration=4.0)
    # 1.0 + 2.0 * t for t = 0..3
    assert series == [1.0, 3.0, 5.0, 7.0]
|
@@ -0,0 +1,23 @@
|
|
1
|
+
'''
|
2
|
+
Test cases for the scp_agent
|
3
|
+
'''
|
4
|
+
|
5
|
+
# from ..tools.search_studies import search_studies
|
6
|
+
from aiagents4pharma.talk2cells.agents.scp_agent import get_app
|
7
|
+
from langchain_core.messages import HumanMessage
|
8
|
+
|
9
|
+
def test_agent_scp():
    '''
    Check that the SCP agent answers a study-search request with text.
    '''
    thread_id = 12345
    app = get_app(thread_id)
    run_config = {"configurable": {"thread_id": thread_id}}
    question = "Search for studies on Crohns Disease."
    result = app.invoke({"messages": [HumanMessage(content=question)]},
                        config=run_config)
    final_reply = result["messages"][-1].content
    assert isinstance(final_reply, str)
|
File without changes
|
File without changes
|
@@ -0,0 +1,242 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for datasets/biobridge_primekg.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import shutil
|
7
|
+
import pytest
|
8
|
+
from ..datasets.biobridge_primekg import BioBridgePrimeKG
|
9
|
+
|
10
|
+
# Scratch directories for these tests. BioBridge data goes to LOCAL_DIR and
# PrimeKG data to PRIMEKG_LOCAL_DIR.
LOCAL_DIR = "../data/biobridge_primekg_test/"
PRIMEKG_LOCAL_DIR = "../data/primekg_test/"
# Only the BioBridge directory is wiped (if present), and this happens at
# import time, so test_download_primekg starts without local BioBridge data.
shutil.rmtree(LOCAL_DIR, ignore_errors=True)
|
14
|
+
|
15
|
+
@pytest.fixture(name="biobridge_primekg")
def biobridge_primekg_fixture():
    """
    Provide a BioBridgePrimeKG instance.

    The instance is configured with PRIMEKG_LOCAL_DIR as its PrimeKG
    directory and LOCAL_DIR as its local directory.
    """
    dataset = BioBridgePrimeKG(primekg_dir=PRIMEKG_LOCAL_DIR,
                               local_dir=LOCAL_DIR)
    return dataset
|
22
|
+
|
23
|
+
def test_download_primekg(biobridge_primekg):
    """
    Check BioBridgePrimeKG.load_data when the data must be downloaded.

    Verifies the expected files on disk and the sizes of every loaded
    structure.
    """
    biobridge_primekg.load_data()
    primekg_nodes = biobridge_primekg.get_primekg().get_nodes()
    primekg_edges = biobridge_primekg.get_primekg().get_edges()
    data_config = biobridge_primekg.get_data_config()
    emb_dict = biobridge_primekg.get_node_embeddings()
    triplets = biobridge_primekg.get_primekg_triplets()
    splits = biobridge_primekg.get_train_test_split()
    node_info = biobridge_primekg.get_node_info_dict()

    # Both working directories must exist after loading
    assert os.path.exists(biobridge_primekg.primekg_dir)
    assert os.path.exists(biobridge_primekg.local_dir)

    # Every downloaded or processed artifact must be on disk
    expected_paths = (
        [f"{biobridge_primekg.primekg_dir}/{name}"
         for name in ("nodes.tab", "primekg_nodes.tsv.gz",
                      "edges.csv", "primekg_edges.tsv.gz")]
        + [f"{biobridge_primekg.local_dir}/data_config.json"]
        + [f"{biobridge_primekg.local_dir}/embeddings/{name}"
           for name in ("protein.pkl", "mf.pkl", "cc.pkl", "bp.pkl",
                        "drug.pkl", "disease.pkl", "embedding_dict.pkl")]
        + [f"{biobridge_primekg.local_dir}/processed/{name}"
           for name in ("protein.csv", "mf.csv", "cc.csv", "bp.csv",
                        "drug.csv", "disease.csv",
                        "triplet_full.tsv.gz", "triplet_full_altered.tsv.gz",
                        "node_train.tsv.gz", "triplet_train.tsv.gz",
                        "node_test.tsv.gz", "triplet_test.tsv.gz")]
    )
    for path in expected_paths:
        assert os.path.exists(path)

    # PrimeKG dataframes
    for frame, expected_rows in ((primekg_nodes, 129375), (primekg_edges, 8100498)):
        assert frame is not None
        assert len(frame) > 0
        assert frame.shape[0] == expected_rows

    # BioBridge data config
    assert data_config is not None
    assert len(data_config) > 0
    for key, size in {"node_type": 10, "relation_type": 18, "emb_dim": 6}.items():
        assert len(data_config[key]) == size

    # BioBridge embeddings
    assert emb_dict is not None
    assert len(emb_dict) > 0
    assert len(emb_dict) == 85466

    # BioBridge triplets and train/test splits
    assert triplets is not None
    assert len(triplets) > 0
    assert triplets.shape[0] == 3904610
    expected_split_sizes = {"train": 3510930, "node_train": 76486,
                            "test": 393680, "node_test": 8495}
    assert list(splits.keys()) == list(expected_split_sizes)
    for split_name, size in expected_split_sizes.items():
        assert len(splits[split_name]) == size

    # Node info dictionary (key order matters)
    expected_node_counts = {"gene/protein": 19162,
                            "molecular_function": 10966,
                            "cellular_component": 4013,
                            "biological_process": 27478,
                            "drug": 6948,
                            "disease": 44133}
    assert list(node_info.keys()) == list(expected_node_counts)
    for node_type, count in expected_node_counts.items():
        assert len(node_info[node_type]) == count
|
121
|
+
|
122
|
+
|
123
|
+
def test_load_existing_primekg(biobridge_primekg):
    """
    Check BioBridgePrimeKG.load_data when the data already exists locally.

    Verifies the expected files on disk and the sizes of every loaded
    structure.
    """
    biobridge_primekg.load_data()
    primekg_nodes = biobridge_primekg.get_primekg().get_nodes()
    primekg_edges = biobridge_primekg.get_primekg().get_edges()
    data_config = biobridge_primekg.get_data_config()
    emb_dict = biobridge_primekg.get_node_embeddings()
    triplets = biobridge_primekg.get_primekg_triplets()
    splits = biobridge_primekg.get_train_test_split()
    node_info = biobridge_primekg.get_node_info_dict()

    # Both working directories must exist after loading
    assert os.path.exists(biobridge_primekg.primekg_dir)
    assert os.path.exists(biobridge_primekg.local_dir)

    # Every downloaded or processed artifact must be on disk
    expected_paths = (
        [f"{biobridge_primekg.primekg_dir}/{name}"
         for name in ("nodes.tab", "primekg_nodes.tsv.gz",
                      "edges.csv", "primekg_edges.tsv.gz")]
        + [f"{biobridge_primekg.local_dir}/data_config.json"]
        + [f"{biobridge_primekg.local_dir}/embeddings/{name}"
           for name in ("protein.pkl", "mf.pkl", "cc.pkl", "bp.pkl",
                        "drug.pkl", "disease.pkl", "embedding_dict.pkl")]
        + [f"{biobridge_primekg.local_dir}/processed/{name}"
           for name in ("protein.csv", "mf.csv", "cc.csv", "bp.csv",
                        "drug.csv", "disease.csv",
                        "triplet_full.tsv.gz", "triplet_full_altered.tsv.gz",
                        "node_train.tsv.gz", "triplet_train.tsv.gz",
                        "node_test.tsv.gz", "triplet_test.tsv.gz")]
    )
    for path in expected_paths:
        assert os.path.exists(path)

    # PrimeKG dataframes
    for frame, expected_rows in ((primekg_nodes, 129375), (primekg_edges, 8100498)):
        assert frame is not None
        assert len(frame) > 0
        assert frame.shape[0] == expected_rows

    # BioBridge data config
    assert data_config is not None
    assert len(data_config) > 0
    for key, size in {"node_type": 10, "relation_type": 18, "emb_dim": 6}.items():
        assert len(data_config[key]) == size

    # BioBridge embeddings
    assert emb_dict is not None
    assert len(emb_dict) > 0
    assert len(emb_dict) == 85466

    # BioBridge triplets and train/test splits
    assert triplets is not None
    assert len(triplets) > 0
    assert triplets.shape[0] == 3904610
    expected_split_sizes = {"train": 3510930, "node_train": 76486,
                            "test": 393680, "node_test": 8495}
    assert list(splits.keys()) == list(expected_split_sizes)
    for split_name, size in expected_split_sizes.items():
        assert len(splits[split_name]) == size

    # Node info dictionary (key order matters)
    expected_node_counts = {"gene/protein": 19162,
                            "molecular_function": 10966,
                            "cellular_component": 4013,
                            "biological_process": 27478,
                            "drug": 6948,
                            "disease": 44133}
    assert list(node_info.keys()) == list(expected_node_counts)
    for node_type, count in expected_node_counts.items():
        assert len(node_info[node_type]) == count
|
221
|
+
|
222
|
+
# def test_load_existing_primekg_with_negative_triplets(biobridge_primekg):
|
223
|
+
# """
|
224
|
+
# Test the loading method of the BioBridge-PrimeKG class by loading existing data in local.
|
225
|
+
# In addition, it builds negative triplets for training data.
|
226
|
+
# """
|
227
|
+
# # Load BioBridge-PrimeKG data
|
228
|
+
# # Using 1 negative sample per positive triplet
|
229
|
+
# biobridge_primekg.load_data(build_neg_triplest=True, n_neg_samples=1)
|
230
|
+
# biobridge_neg_triplets = biobridge_primekg.get_primekg_triplets_negative()
|
231
|
+
|
232
|
+
# # Check if the local directories exists
|
233
|
+
# assert os.path.exists(biobridge_primekg.primekg_dir)
|
234
|
+
# assert os.path.exists(biobridge_primekg.local_dir)
|
235
|
+
# # Check if downloaded and processed files exist
|
236
|
+
# path = f"{biobridge_primekg.local_dir}/processed/triplet_train_negative.tsv.gz"
|
237
|
+
# assert os.path.exists(path)
|
238
|
+
# # Check processed BioBridge triplets
|
239
|
+
# assert biobridge_neg_triplets is not None
|
240
|
+
# assert len(biobridge_neg_triplets) > 0
|
241
|
+
# assert biobridge_neg_triplets.shape[0] == 3510930
|
242
|
+
# assert len(biobridge_neg_triplets.negative_tail_index[0]) == 1
|
@@ -0,0 +1,29 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for datasets/dataset.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
from ..datasets.dataset import Dataset
|
6
|
+
|
7
|
+
class MockDataset(Dataset):
    """
    Minimal concrete Dataset whose setup and load_data do nothing.

    Used to exercise the Dataset interface in tests.
    """

    def setup(self):
        """No setup is needed for the mock."""
        return None

    def load_data(self):
        """The mock has no data to load."""
        return None
|
16
|
+
|
17
|
+
def test_dataset_setup():
    """
    Check that calling setup on a Dataset subclass returns None.
    """
    result = MockDataset().setup()
    assert result is None
|
23
|
+
|
24
|
+
def test_dataset_load_data():
    """
    Check that calling load_data on a Dataset subclass returns None.
    """
    result = MockDataset().load_data()
    assert result is None
|
@@ -0,0 +1,73 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for datasets/primekg.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import shutil
|
7
|
+
import pytest
|
8
|
+
from ..datasets.primekg import PrimeKG
|
9
|
+
|
10
|
+
# Scratch directory for the PrimeKG tests. It is wiped (if present) at
# import time, so test_download_primekg starts from an empty directory.
LOCAL_DIR = "../data/primekg_test/"
shutil.rmtree(LOCAL_DIR, ignore_errors=True)
|
13
|
+
|
14
|
+
@pytest.fixture(name="primekg")
def primekg_fixture():
    """
    Provide a PrimeKG instance whose local directory is LOCAL_DIR.
    """
    dataset = PrimeKG(local_dir=LOCAL_DIR)
    return dataset
|
20
|
+
|
21
|
+
def test_download_primekg(primekg):
    """
    Check PrimeKG.load_data when the data must be downloaded from the server.
    """
    primekg.load_data()
    nodes = primekg.get_nodes()
    edges = primekg.get_edges()

    # The local directory and all raw/processed files must exist
    assert os.path.exists(primekg.local_dir)
    file_names = ("nodes.tab", f"{primekg.name}_nodes.tsv.gz",
                  "edges.csv", f"{primekg.name}_edges.tsv.gz")
    for file_name in file_names:
        assert os.path.exists(f"{primekg.local_dir}/{file_name}")

    # Processed dataframes must be non-empty with the expected row counts
    for frame, expected_rows in ((nodes, 129375), (edges, 8100498)):
        assert frame is not None
        assert len(frame) > 0
        assert frame.shape[0] == expected_rows
|
47
|
+
|
48
|
+
def test_load_existing_primekg(primekg):
    """
    Check PrimeKG.load_data when PrimeKG already exists in the local directory.
    """
    primekg.load_data()
    nodes = primekg.get_nodes()
    edges = primekg.get_edges()

    # The local directory and all raw/processed files must exist
    assert os.path.exists(primekg.local_dir)
    file_names = ("nodes.tab", f"{primekg.name}_nodes.tsv.gz",
                  "edges.csv", f"{primekg.name}_edges.tsv.gz")
    for file_name in file_names:
        assert os.path.exists(f"{primekg.local_dir}/{file_name}")

    # Processed dataframes must be non-empty with the expected row counts
    for frame, expected_rows in ((nodes, 129375), (edges, 8100498)):
        assert frame is not None
        assert len(frame) > 0
        assert frame.shape[0] == expected_rows
|
@@ -0,0 +1,116 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for datasets/starkqa_primekg.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
import shutil
|
7
|
+
import pytest
|
8
|
+
from ..datasets.starkqa_primekg import StarkQAPrimeKG
|
9
|
+
|
10
|
+
# Scratch directory for the StarkQA-PrimeKG tests. It is wiped (if present)
# at import time, so the download test starts from an empty directory.
LOCAL_DIR = "../data/starkqa_primekg_test/"
shutil.rmtree(LOCAL_DIR, ignore_errors=True)
|
13
|
+
|
14
|
+
@pytest.fixture(name="starkqa_primekg")
def starkqa_primekg_fixture():
    """
    Provide a StarkQAPrimeKG instance whose local directory is LOCAL_DIR.
    """
    dataset = StarkQAPrimeKG(local_dir=LOCAL_DIR)
    return dataset
|
20
|
+
|
21
|
+
def test_download_starkqa_primekg(starkqa_primekg):
    """
    Check StarkQAPrimeKG.load_data when files must be downloaded from the
    HuggingFace Hub.
    """
    starkqa_primekg.load_data()
    starkqa_df = starkqa_primekg.get_starkqa()
    node_info = starkqa_primekg.get_starkqa_node_info()
    split_idx = starkqa_primekg.get_starkqa_split_indicies()
    query_emb = starkqa_primekg.get_query_embeddings()
    node_emb = starkqa_primekg.get_node_embeddings()

    # The local directory and every downloaded file must exist
    assert os.path.exists(starkqa_primekg.local_dir)
    downloaded = ('qa/prime/split/test-0.1.index',
                  'qa/prime/split/test.index',
                  'qa/prime/split/train.index',
                  'qa/prime/split/val.index',
                  'qa/prime/stark_qa/stark_qa.csv',
                  'qa/prime/stark_qa/stark_qa_human_generated_eval.csv',
                  'skb/prime/processed.zip')
    for rel_path in downloaded:
        assert os.path.exists(f"{starkqa_primekg.local_dir}/{rel_path}")

    # QA dataframe
    assert starkqa_df is not None
    assert len(starkqa_df) > 0
    assert starkqa_df.shape[0] == 11204

    # Node information
    assert node_info is not None
    assert len(node_info) == 129375

    # Split indices (key order matters)
    expected_split_sizes = {'train': 6162, 'val': 2241, 'test': 2801, 'test-0.1': 280}
    assert list(split_idx.keys()) == list(expected_split_sizes)
    for split_name, size in expected_split_sizes.items():
        assert len(split_idx[split_name]) == size

    # Query and node embeddings: expected count, second dimension of 1536
    for embeddings, expected_count in ((query_emb, 11204), (node_emb, 129375)):
        assert embeddings is not None
        assert len(embeddings) == expected_count
        assert embeddings[0].shape[1] == 1536
|
68
|
+
|
69
|
+
def test_load_existing_starkqa_primekg(starkqa_primekg):
    """
    Test the loading method of the StarkQAPrimeKG class by loading existing
    files in the local directory.
    """
    # Load StarkQA PrimeKG data and fetch every artefact the loader exposes.
    starkqa_primekg.load_data()
    starkqa_df = starkqa_primekg.get_starkqa()
    primekg_node_info = starkqa_primekg.get_starkqa_node_info()
    # NOTE: "indicies" is the spelling used by the StarkQAPrimeKG API itself.
    split_idx = starkqa_primekg.get_starkqa_split_indicies()
    query_embeddings = starkqa_primekg.get_query_embeddings()
    node_embeddings = starkqa_primekg.get_node_embeddings()

    # Check if the local directory exists
    assert os.path.exists(starkqa_primekg.local_dir)
    # Check if downloaded and processed files exist
    files = [
        "qa/prime/split/test-0.1.index",
        "qa/prime/split/test.index",
        "qa/prime/split/train.index",
        "qa/prime/split/val.index",
        "qa/prime/stark_qa/stark_qa.csv",
        "qa/prime/stark_qa/stark_qa_human_generated_eval.csv",
        "skb/prime/processed.zip",
    ]
    for rel_path in files:
        assert os.path.exists(f"{starkqa_primekg.local_dir}/{rel_path}")

    # QA dataframe: 11,204 rows
    assert starkqa_df is not None
    assert len(starkqa_df) > 0
    assert starkqa_df.shape[0] == 11204

    # Node information: 129,375 entries
    assert primekg_node_info is not None
    assert len(primekg_node_info) == 129375

    # Split indices, in the exact key order the loader returns them.
    # train + val + test (6162 + 2241 + 2801) sum to the 11,204 QA rows;
    # test-0.1 has 280 entries (presumably a ~10% sample of test - not
    # verified here).
    expected_split_sizes = {"train": 6162, "val": 2241, "test": 2801, "test-0.1": 280}
    assert list(split_idx.keys()) == list(expected_split_sizes)
    for split_name, size in expected_split_sizes.items():
        assert len(split_idx[split_name]) == size

    # Query embeddings: one per QA row, 1536 values along the second axis.
    assert query_embeddings is not None
    assert len(query_embeddings) == 11204
    assert query_embeddings[0].shape[1] == 1536

    # Node embeddings: one per node, 1536 values along the second axis.
    assert node_embeddings is not None
    assert len(node_embeddings) == 129375
    assert node_embeddings[0].shape[1] == 1536
|
@@ -0,0 +1,47 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for utils/embeddings/embeddings.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import pytest
|
6
|
+
from ..utils.embeddings.embeddings import Embeddings
|
7
|
+
|
8
|
+
class TestEmbeddings(Embeddings):
    """Test implementation of the Embeddings interface for testing purposes.

    Every text maps to the same fixed 3-element vector, so results are
    deterministic and easy to assert on.
    """

    # The "Test" name prefix would otherwise make pytest try to collect this
    # helper as a test class; opt it out of collection explicitly.
    __test__ = False

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Return the fixed vector [0.1, 0.2, 0.3] once per input text."""
        return [[0.1, 0.2, 0.3] for _ in texts]

    def embed_query(self, text: str) -> list[float]:
        """Return the fixed vector [0.1, 0.2, 0.3], ignoring the query text."""
        return [0.1, 0.2, 0.3]
|
16
|
+
|
17
|
+
def test_embed_documents():
    """Embedding a batch of documents yields the stub vector for each text."""
    stub = TestEmbeddings()
    vectors = stub.embed_documents(["text1", "text2"])
    expected = [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]
    assert vectors == expected
|
23
|
+
|
24
|
+
|
25
|
+
def test_embed_query():
    """Embedding a single query yields the stub's fixed vector."""
    stub = TestEmbeddings()
    vector = stub.embed_query("query")
    assert vector == [0.1, 0.2, 0.3]
|
31
|
+
|
32
|
+
@pytest.mark.asyncio
async def test_aembed_documents():
    """Awaiting aembed_documents returns the stub vector for every document."""
    stub = TestEmbeddings()
    docs = ["text1", "text2"]
    expected = [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]
    assert await stub.aembed_documents(docs) == expected
|
39
|
+
|
40
|
+
|
41
|
+
@pytest.mark.asyncio
async def test_aembed_query():
    """Awaiting aembed_query returns the stub's fixed vector."""
    stub = TestEmbeddings()
    assert await stub.aembed_query("query") == [0.1, 0.2, 0.3]
|
@@ -0,0 +1,45 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for utils/embeddings/huggingface.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import re

import pytest
from ..utils.embeddings.huggingface import EmbeddingWithHuggingFace
|
7
|
+
|
8
|
+
@pytest.fixture(name="embedding_model")
def embedding_model_fixture():
    """Return an EmbeddingWithHuggingFace instance for the
    NeuML/pubmedbert-base-embeddings model, with truncation enabled."""
    return EmbeddingWithHuggingFace(
        model_name="NeuML/pubmedbert-base-embeddings",
        model_cache_dir="../../cache",
        truncation=True,
    )
|
16
|
+
|
17
|
+
def test_embedding_with_huggingface_embed_documents(embedding_model):
    """Embed several drug names and check count and vector width (768)."""
    drug_names = ["Adalimumab", "Infliximab", "Vedolizumab"]
    vectors = embedding_model.embed_documents(drug_names)
    # One vector per input ...
    assert len(vectors) == 3
    # ... and the first one has 768 dimensions.
    assert len(vectors[0]) == 768
|
25
|
+
|
26
|
+
def test_embedding_with_huggingface_embed_query(embedding_model):
    """Embed a single query and check the vector width (768)."""
    vector = embedding_model.embed_query("Adalimumab")
    assert len(vector) == 768
|
33
|
+
|
34
|
+
|
35
|
+
def test_embedding_with_huggingface_failed():
    """Constructing EmbeddingWithHuggingFace with a model that is not on the
    HuggingFace Hub raises a ValueError whose message names that model."""
    model_name = "aiagents4pharma/embeddings"
    err_msg = f"Model {model_name} is not available on HuggingFace Hub."
    # pytest.raises(match=...) applies re.search with the pattern, so escape
    # the literal message to keep characters such as '.' from acting as regex
    # metacharacters.
    with pytest.raises(ValueError, match=re.escape(err_msg)):
        EmbeddingWithHuggingFace(
            model_name=model_name,
            model_cache_dir="../../cache",
            truncation=True,
        )
|
@@ -0,0 +1,40 @@
|
|
1
|
+
"""
|
2
|
+
Test cases for utils/embeddings/sentence_transformer.py
|
3
|
+
"""
|
4
|
+
|
5
|
+
import pytest
|
6
|
+
import numpy as np
|
7
|
+
from ..utils.embeddings.sentence_transformer import EmbeddingWithSentenceTransformer
|
8
|
+
|
9
|
+
@pytest.fixture(name="embedding_model")
def embedding_model_fixture():
    """
    Provide an EmbeddingWithSentenceTransformer built on a small model so the
    tests stay lightweight.
    """
    return EmbeddingWithSentenceTransformer(
        model_name="sentence-transformers/all-MiniLM-L6-v1"  # Small model for testing
    )
|
16
|
+
|
17
|
+
def test_embed_documents(embedding_model):
    """
    Check EmbeddingWithSentenceTransformer.embed_documents on a two-sentence batch.
    """
    sentences = ["This is a test sentence.", "Another test sentence."]
    vectors = embedding_model.embed_documents(sentences)
    # One row per input sentence.
    assert len(vectors) == len(sentences)
    # The first row is non-empty and 384-dimensional.
    first_row = vectors[0]
    assert len(first_row) > 0
    assert len(first_row) == 384
    # The whole result is a float32 array.
    assert vectors.dtype == np.float32
|
29
|
+
|
30
|
+
def test_embed_query(embedding_model):
    """
    Check EmbeddingWithSentenceTransformer.embed_query on a single query.
    """
    vector = embedding_model.embed_query("This is a test query.")
    # Non-empty, 384-dimensional, float32.
    assert len(vector) > 0
    assert len(vector) == 384
    assert vector.dtype == np.float32
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: aiagents4pharma
|
3
|
-
Version: 1.8.
|
3
|
+
Version: 1.8.3
|
4
4
|
Summary: AI Agents for drug discovery, drug development, and other pharmaceutical R&D
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
6
6
|
Classifier: License :: OSI Approved :: MIT License
|
@@ -1,12 +1,22 @@
|
|
1
|
-
aiagents4pharma/__init__.py,sha256=
|
1
|
+
aiagents4pharma/__init__.py,sha256=X-Mpbf4sfjMvoxiTTl4qFgCxrWZgtVgRjnjX6gbUtCg,173
|
2
|
+
aiagents4pharma/configs/__init__.py,sha256=hNkSrXw1Ix1HhkGn_aaidr2coBYySfM0Hm_pMeRcX7k,76
|
3
|
+
aiagents4pharma/configs/config.yaml,sha256=8y8uG6Dzx4-9jyb6hZ8r4lOJz5gA_sQhCiSCgXL5l7k,65
|
4
|
+
aiagents4pharma/configs/talk2biomodels/__init__.py,sha256=5ah__-8XyRblwT0U1ByRigNjt_GyCheu7zce4aM-eZE,68
|
5
|
+
aiagents4pharma/configs/talk2biomodels/agents/__init__.py,sha256=_ZoG8snICK2bidWtc2KOGs738LWg9_r66V9mOMnEb-E,71
|
6
|
+
aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/__init__.py,sha256=-fAORvyFmG2iSvFOFDixmt9OTQRR58y89uhhu2EgbA8,46
|
7
|
+
aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml,sha256=-IwTmcZFlgUzxxHgcaR9jmbVKCP3IFEHtqzAD4pek6s,200
|
2
8
|
aiagents4pharma/talk2biomodels/__init__.py,sha256=9MuyHb5KTf5ufeyq7fu5xoLEwVT688DFOWuKuzdWG9o,140
|
3
9
|
aiagents4pharma/talk2biomodels/agents/__init__.py,sha256=sn5-fREjMdEvb-OUan3iOqrgYGjplNx3J8hYOaW0Po8,128
|
4
|
-
aiagents4pharma/talk2biomodels/agents/t2b_agent.py,sha256=
|
10
|
+
aiagents4pharma/talk2biomodels/agents/t2b_agent.py,sha256=nVWxHR-QMZDqDwxvDga_CvLo7LHP5cWCDl6lXCMcRO0,3264
|
5
11
|
aiagents4pharma/talk2biomodels/models/__init__.py,sha256=5fTHHm3PVloYPNKXbgNlcPgv3-u28ZquxGydFYDfhJA,122
|
6
12
|
aiagents4pharma/talk2biomodels/models/basico_model.py,sha256=js7ORLwbJPaIsko5oRToMMCh4l8LsN292OIvFzTfvRg,4946
|
7
13
|
aiagents4pharma/talk2biomodels/models/sys_bio_model.py,sha256=ylpPba2SA8kl68q3k1kJbiUdRYplPHykyslTQLDZ19I,1995
|
8
14
|
aiagents4pharma/talk2biomodels/states/__init__.py,sha256=YLg1-N0D9qyRRLRqwqfLCLAqZYDtMVZTfI8Y0b_4tbA,139
|
9
15
|
aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py,sha256=cneAsNgtzDoL3_gkC4gf7h69BfBOY7rASpYWz4I8jjI,761
|
16
|
+
aiagents4pharma/talk2biomodels/tests/__init__.py,sha256=Jbw5tJxSrjGoaK5IX3pJWDCNzhrVQ10lkYq2oQ_KQD8,45
|
17
|
+
aiagents4pharma/talk2biomodels/tests/test_basico_model.py,sha256=uqhbojcA4RRTDRUAF9B9DzKCo3OOIOWMDK8IViG0gsM,2038
|
18
|
+
aiagents4pharma/talk2biomodels/tests/test_langgraph.py,sha256=z3bS_LE2s9kba2fFM5SWQy1uIalQBbCnGJuakQjnKTQ,7365
|
19
|
+
aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py,sha256=nA6bRT16627mw8qzrv7cHM9AByHb9F0kxAuwOpE-avA,1961
|
10
20
|
aiagents4pharma/talk2biomodels/tools/__init__.py,sha256=8hAT6z1OO8N9HRylh6fwoqyjYlGdpkngkElBNqH40Zo,237
|
11
21
|
aiagents4pharma/talk2biomodels/tools/ask_question.py,sha256=UEdT46DIFZHlAOF4cNX4_s7VjHvbbiGpNmEY2-XW2iA,2655
|
12
22
|
aiagents4pharma/talk2biomodels/tools/custom_plotter.py,sha256=BfLSivWqtRVZDeqD7RP5sgZP6B9rET47IWdNrA3gmIE,3825
|
@@ -19,24 +29,34 @@ aiagents4pharma/talk2cells/agents/__init__.py,sha256=38nK2a_lEFRjO3qD6Fo9a3983ZC
|
|
19
29
|
aiagents4pharma/talk2cells/agents/scp_agent.py,sha256=gDMfhUNWHa_XWOqm1Ql6yLAdI_7bnIk5sRYn43H2sYk,3090
|
20
30
|
aiagents4pharma/talk2cells/states/__init__.py,sha256=e4s8pHZaR6UC42DtmsOzCVms5gxp5QEzLE4bG54YYko,135
|
21
31
|
aiagents4pharma/talk2cells/states/state_talk2cells.py,sha256=en5LikmabPZA6lLVpmYXff0Q3Fno0N2PBSMxk3gLWaE,253
|
32
|
+
aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py,sha256=f8bsUXvEElvYM5KGvpPzI1xOR5Y_zpm3_Jzk7gCX1CE,732
|
22
33
|
aiagents4pharma/talk2cells/tools/__init__.py,sha256=38nK2a_lEFRjO3qD6Fo9a3983ZCYat6hmJKWY61y2Mo,128
|
23
34
|
aiagents4pharma/talk2cells/tools/scp_agent/__init__.py,sha256=s7g0lyH1lMD9pcWHLPtwRJRvzmTh2II7DrxyLulpjmQ,163
|
24
35
|
aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py,sha256=6q59gh_NQaiOU2rn55A3sIIFKlXi4SK3iKgySvUDrtQ,600
|
25
36
|
aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py,sha256=MLe-twtFnOu-P8P9diYq7jvHBHbWFRRCZLcfpUzqPMg,2806
|
37
|
+
aiagents4pharma/talk2competitors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
26
38
|
aiagents4pharma/talk2knowledgegraphs/__init__.py,sha256=SW7Ys2A4eXyFtizNPdSw91SHOPVUBGBsrCQ7TqwSUL0,91
|
27
39
|
aiagents4pharma/talk2knowledgegraphs/datasets/__init__.py,sha256=L3gPuHskSegmtXskVrLIYr7FXe_ibKgJ2GGr1_Wok6k,173
|
28
40
|
aiagents4pharma/talk2knowledgegraphs/datasets/biobridge_primekg.py,sha256=QlzDXmXREoa9MA6-GwzqRjdzndQeGBAF11Td6NFk_9Y,23426
|
29
41
|
aiagents4pharma/talk2knowledgegraphs/datasets/dataset.py,sha256=-LaPLse8BkALqwFetNK7wch2dt9Dz6QKGKZKBKM6bIk,409
|
30
42
|
aiagents4pharma/talk2knowledgegraphs/datasets/primekg.py,sha256=KBMhCJ7yjMWqQJJctFYdpjYAlwv48Jl6i1dddXP4f08,7599
|
31
43
|
aiagents4pharma/talk2knowledgegraphs/datasets/starkqa_primekg.py,sha256=Y-6-nORsnBJlU6rH0skyfr9S9J4PfTWK-af_p5UuknQ,7483
|
44
|
+
aiagents4pharma/talk2knowledgegraphs/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
45
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py,sha256=crH0eFA3P8P6IYzi1UWNa4YvRVrtlBzoScf9NaE1lDk,9827
|
46
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py,sha256=NFUlsZvhfIrkF4YenWfahrLK93Xhm5UYEGG_uYN2LVM,566
|
47
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py,sha256=Pvu0r93CpnhjkfMxc-EiVLpAJ04FdW9iTamCnetu654,2272
|
48
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py,sha256=TuIsqcN1Mww3DTqGk6ebgJBWzUWdMWEq2yRQuYSFqvA,4416
|
49
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py,sha256=uYFoE_6zeU10_1mLLAHUr5c4S2XZMSc0Q_860o-KWEw,1517
|
50
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py,sha256=EINWyXg_3AMHF3WzFLhIUiFDuaEhTVHBvVAJr8VtMDg,1624
|
51
|
+
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py,sha256=Qxo6WeIDRy8aLh1tNKw0kSlzmUj3MtTak63oW2YwB24,1327
|
32
52
|
aiagents4pharma/talk2knowledgegraphs/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
33
53
|
aiagents4pharma/talk2knowledgegraphs/utils/kg_utils.py,sha256=6vQnPkeOWae_8jePjhma3sJuMTngy0I0tqzdFt6OqKg,2507
|
34
54
|
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py,sha256=xRb0x7SoAb0nSVZYgjrqxWvENOMDuqIdL43NMjoOaCs,153
|
35
55
|
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/embeddings.py,sha256=1nGznrAj-xT0xuSMBGz2dOujJ7M_IwSR84njxtxsy9A,2523
|
36
56
|
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/huggingface.py,sha256=2vi_elf6EgzfagFAO5QnL3a_aXZyN7B1EBziu44MTfM,3806
|
37
57
|
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/sentence_transformer.py,sha256=36iKlisOpMtGR5xfTAlSHXWvPqVC_Jbezod8kbBBMVg,2136
|
38
|
-
aiagents4pharma-1.8.
|
39
|
-
aiagents4pharma-1.8.
|
40
|
-
aiagents4pharma-1.8.
|
41
|
-
aiagents4pharma-1.8.
|
42
|
-
aiagents4pharma-1.8.
|
58
|
+
aiagents4pharma-1.8.3.dist-info/LICENSE,sha256=IcIbyB1Hyk5ZDah03VNQvJkbNk2hkBCDqQ8qtnCvB4Q,1077
|
59
|
+
aiagents4pharma-1.8.3.dist-info/METADATA,sha256=qIsojNDT5C2dKHbTLBc7NfcCUyin52GWeP3G0GfU-mE,7988
|
60
|
+
aiagents4pharma-1.8.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
61
|
+
aiagents4pharma-1.8.3.dist-info/top_level.txt,sha256=-AH8rMmrSnJtq7HaAObS78UU-cTCwvX660dSxeM7a0A,16
|
62
|
+
aiagents4pharma-1.8.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|