aiagents4pharma-1.8.2-py3-none-any.whl → aiagents4pharma-1.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. aiagents4pharma/__init__.py +1 -0
  2. aiagents4pharma/configs/__init__.py +5 -0
  3. aiagents4pharma/configs/config.yaml +3 -0
  4. aiagents4pharma/configs/talk2biomodels/__init__.py +5 -0
  5. aiagents4pharma/configs/talk2biomodels/agents/__init__.py +5 -0
  6. aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/__init__.py +3 -0
  7. aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml +8 -0
  8. aiagents4pharma/talk2biomodels/__init__.py +1 -1
  9. aiagents4pharma/talk2biomodels/agents/t2b_agent.py +3 -3
  10. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +1 -1
  11. aiagents4pharma/talk2biomodels/tests/__init__.py +3 -0
  12. aiagents4pharma/talk2biomodels/tests/test_basico_model.py +55 -0
  13. aiagents4pharma/talk2biomodels/tests/test_langgraph.py +240 -0
  14. aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +57 -0
  15. aiagents4pharma/talk2biomodels/tools/ask_question.py +16 -7
  16. aiagents4pharma/talk2biomodels/tools/custom_plotter.py +20 -14
  17. aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +6 -6
  18. aiagents4pharma/talk2biomodels/tools/simulate_model.py +26 -12
  19. aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +23 -0
  20. aiagents4pharma/talk2competitors/__init__.py +0 -0
  21. aiagents4pharma/talk2knowledgegraphs/tests/__init__.py +0 -0
  22. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +242 -0
  23. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +29 -0
  24. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +73 -0
  25. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +116 -0
  26. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +47 -0
  27. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +45 -0
  28. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py +40 -0
  29. {aiagents4pharma-1.8.2.dist-info → aiagents4pharma-1.9.0.dist-info}/METADATA +1 -1
  30. aiagents4pharma-1.9.0.dist-info/RECORD +62 -0
  31. aiagents4pharma-1.8.2.dist-info/RECORD +0 -42
  32. {aiagents4pharma-1.8.2.dist-info → aiagents4pharma-1.9.0.dist-info}/LICENSE +0 -0
  33. {aiagents4pharma-1.8.2.dist-info → aiagents4pharma-1.9.0.dist-info}/WHEEL +0 -0
  34. {aiagents4pharma-1.8.2.dist-info → aiagents4pharma-1.9.0.dist-info}/top_level.txt +0 -0
@@ -52,10 +52,10 @@ class TimeSpeciesNameConcentration:
52
52
  class RecurringData:
53
53
  """
54
54
  Dataclass for storing the species and time data
55
- on recurring basis.
55
+ on reocurring basis.
56
56
  """
57
57
  data: List[TimeSpeciesNameConcentration] = Field(
58
- description="species and time data on recurring basis",
58
+ description="species and time data on reocurring basis",
59
59
  default=None)
60
60
 
61
61
  @dataclass
@@ -68,12 +68,15 @@ class ArgumentData:
68
68
  description="species name and initial concentration data",
69
69
  default=None)
70
70
  recurring_data: RecurringData = Field(
71
- description="species and time data on recurring basis",
71
+ description="species and time data on reocurring basis",
72
72
  default=None)
73
+ simulation_name: str = Field(
74
+ description="""An AI assigned `_` separated name of
75
+ the simulation based on human query""")
73
76
 
74
77
  def add_rec_events(model_object, recurring_data):
75
78
  """
76
- Add recurring events to the model.
79
+ Add reocurring events to the model.
77
80
  """
78
81
  for row in recurring_data.data:
79
82
  tp, sn, sc = row.time, row.species_name, row.species_concentration
@@ -86,9 +89,12 @@ class SimulateModelInput(BaseModel):
86
89
  """
87
90
  Input schema for the SimulateModel tool.
88
91
  """
89
- sys_bio_model: ModelData = Field(description="model data", default=None)
90
- arg_data: ArgumentData = Field(description="time, species, and recurring data",
91
- default=None)
92
+ sys_bio_model: ModelData = Field(description="model data",
93
+ default=None)
94
+ arg_data: ArgumentData = Field(description=
95
+ """time, species, and reocurring data
96
+ as well as the simulation name""",
97
+ default=None)
92
98
  tool_call_id: Annotated[str, InjectedToolCallId]
93
99
  state: Annotated[dict, InjectedState]
94
100
 
@@ -153,12 +159,20 @@ class SimulateModelTool(BaseTool):
153
159
  interval=interval
154
160
  )
155
161
 
162
+ dic_simulated_data = {
163
+ 'name': arg_data.simulation_name,
164
+ 'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload',
165
+ 'tool_call_id': tool_call_id,
166
+ 'data': df.to_dict()
167
+ }
168
+
156
169
  # Prepare the dictionary of updated state for the model
157
170
  dic_updated_state_for_model = {}
158
171
  for key, value in {
159
- "model_id": [sys_bio_model.biomodel_id],
160
- "sbml_file_path": [sbml_file_path],
161
- }.items():
172
+ "model_id": [sys_bio_model.biomodel_id],
173
+ "sbml_file_path": [sbml_file_path],
174
+ "dic_simulated_data": [dic_simulated_data],
175
+ }.items():
162
176
  if value:
163
177
  dic_updated_state_for_model[key] = value
164
178
 
@@ -166,11 +180,11 @@ class SimulateModelTool(BaseTool):
166
180
  return Command(
167
181
  update=dic_updated_state_for_model|{
168
182
  # update the state keys
169
- "dic_simulated_data": df.to_dict(),
183
+ # "dic_simulated_data": df.to_dict(),
170
184
  # update the message history
171
185
  "messages": [
172
186
  ToolMessage(
173
- content="Simulation results are ready.",
187
+ content=f"Simulation results of {arg_data.simulation_name}",
174
188
  tool_call_id=tool_call_id
175
189
  )
176
190
  ],
@@ -0,0 +1,23 @@
1
+ '''
2
+ Test cases for the search_studies
3
+ '''
4
+
5
+ # from ..tools.search_studies import search_studies
6
+ from aiagents4pharma.talk2cells.agents.scp_agent import get_app
7
+ from langchain_core.messages import HumanMessage
8
+
9
+ def test_agent_scp():
10
+ '''
11
+ Test the agent_scp.
12
+ '''
13
+ unique_id = 12345
14
+ app = get_app(unique_id)
15
+ config = {"configurable": {"thread_id": unique_id}}
16
+ prompt = "Search for studies on Crohns Disease."
17
+ response = app.invoke(
18
+ {"messages": [HumanMessage(content=prompt)]},
19
+ config=config
20
+ )
21
+ assistant_msg = response["messages"][-1].content
22
+ # Check if the assistant message is a string
23
+ assert isinstance(assistant_msg, str)
File without changes: aiagents4pharma/talk2competitors/__init__.py (+0 -0)
File without changes: aiagents4pharma/talk2knowledgegraphs/tests/__init__.py (+0 -0)
@@ -0,0 +1,242 @@
1
+ """
2
+ Test cases for datasets/primekg_loader.py
3
+ """
4
+
5
+ import os
6
+ import shutil
7
+ import pytest
8
+ from ..datasets.biobridge_primekg import BioBridgePrimeKG
9
+
10
+ # Remove the data folder for testing if it exists
11
+ PRIMEKG_LOCAL_DIR = "../data/primekg_test/"
12
+ LOCAL_DIR = "../data/biobridge_primekg_test/"
13
+ shutil.rmtree(LOCAL_DIR, ignore_errors=True)
14
+
15
+ @pytest.fixture(name="biobridge_primekg")
16
+ def biobridge_primekg_fixture():
17
+ """
18
+ Fixture for creating an instance of PrimeKG.
19
+ """
20
+ return BioBridgePrimeKG(primekg_dir=PRIMEKG_LOCAL_DIR,
21
+ local_dir=LOCAL_DIR)
22
+
23
+ def test_download_primekg(biobridge_primekg):
24
+ """
25
+ Test the loading method of the BioBridge-PrimeKG class by downloading data from repository.
26
+ """
27
+ # Load BioBridge-PrimeKG data
28
+ biobridge_primekg.load_data()
29
+ primekg_nodes = biobridge_primekg.get_primekg().get_nodes()
30
+ primekg_edges = biobridge_primekg.get_primekg().get_edges()
31
+ biobridge_data_config = biobridge_primekg.get_data_config()
32
+ biobridge_emb_dict = biobridge_primekg.get_node_embeddings()
33
+ biobridge_triplets = biobridge_primekg.get_primekg_triplets()
34
+ biobridge_splits = biobridge_primekg.get_train_test_split()
35
+ biobridge_node_info = biobridge_primekg.get_node_info_dict()
36
+
37
+ # Check if the local directories exists
38
+ assert os.path.exists(biobridge_primekg.primekg_dir)
39
+ assert os.path.exists(biobridge_primekg.local_dir)
40
+ # Check if downloaded and processed files exist
41
+ # PrimeKG files
42
+ files = ["nodes.tab", "primekg_nodes.tsv.gz",
43
+ "edges.csv", "primekg_edges.tsv.gz"]
44
+ for file in files:
45
+ path = f"{biobridge_primekg.primekg_dir}/{file}"
46
+ assert os.path.exists(path)
47
+ # BioBridge data config
48
+ assert os.path.exists(f"{biobridge_primekg.local_dir}/data_config.json")
49
+ # BioBridge embeddings
50
+ files = [
51
+ "protein.pkl",
52
+ "mf.pkl",
53
+ "cc.pkl",
54
+ "bp.pkl",
55
+ "drug.pkl",
56
+ "disease.pkl",
57
+ "embedding_dict.pkl"
58
+ ]
59
+ for file in files:
60
+ path = f"{biobridge_primekg.local_dir}/embeddings/{file}"
61
+ assert os.path.exists(path)
62
+ # BioBridge processed files
63
+ files = [
64
+ "protein.csv",
65
+ "mf.csv",
66
+ "cc.csv",
67
+ "bp.csv",
68
+ "drug.csv",
69
+ "disease.csv",
70
+ "triplet_full.tsv.gz",
71
+ "triplet_full_altered.tsv.gz",
72
+ "node_train.tsv.gz",
73
+ "triplet_train.tsv.gz",
74
+ "node_test.tsv.gz",
75
+ "triplet_test.tsv.gz",
76
+ ]
77
+ for file in files:
78
+ path = f"{biobridge_primekg.local_dir}/processed/{file}"
79
+ assert os.path.exists(path)
80
+ # Check processed PrimeKG dataframes
81
+ # Nodes
82
+ assert primekg_nodes is not None
83
+ assert len(primekg_nodes) > 0
84
+ assert primekg_nodes.shape[0] == 129375
85
+ # Edges
86
+ assert primekg_edges is not None
87
+ assert len(primekg_edges) > 0
88
+ assert primekg_edges.shape[0] == 8100498
89
+ # Check processed BioBridge data config
90
+ assert biobridge_data_config is not None
91
+ assert len(biobridge_data_config) > 0
92
+ assert len(biobridge_data_config['node_type']) == 10
93
+ assert len(biobridge_data_config['relation_type']) == 18
94
+ assert len(biobridge_data_config['emb_dim']) == 6
95
+ # Check processed BioBridge embeddings
96
+ assert biobridge_emb_dict is not None
97
+ assert len(biobridge_emb_dict) > 0
98
+ assert len(biobridge_emb_dict) == 85466
99
+ # Check processed BioBridge triplets
100
+ assert biobridge_triplets is not None
101
+ assert len(biobridge_triplets) > 0
102
+ assert biobridge_triplets.shape[0] == 3904610
103
+ assert list(biobridge_splits.keys()) == ['train', 'node_train', 'test', 'node_test']
104
+ assert len(biobridge_splits['train']) == 3510930
105
+ assert len(biobridge_splits['node_train']) == 76486
106
+ assert len(biobridge_splits['test']) == 393680
107
+ assert len(biobridge_splits['node_test']) == 8495
108
+ # Check node info dictionary
109
+ assert list(biobridge_node_info.keys()) == ['gene/protein',
110
+ 'molecular_function',
111
+ 'cellular_component',
112
+ 'biological_process',
113
+ 'drug',
114
+ 'disease']
115
+ assert len(biobridge_node_info['gene/protein']) == 19162
116
+ assert len(biobridge_node_info['molecular_function']) == 10966
117
+ assert len(biobridge_node_info['cellular_component']) == 4013
118
+ assert len(biobridge_node_info['biological_process']) == 27478
119
+ assert len(biobridge_node_info['drug']) == 6948
120
+ assert len(biobridge_node_info['disease']) == 44133
121
+
122
+
123
+ def test_load_existing_primekg(biobridge_primekg):
124
+ """
125
+ Test the loading method of the BioBridge-PrimeKG class by loading existing data in local.
126
+ """
127
+ # Load BioBridge-PrimeKG data
128
+ biobridge_primekg.load_data()
129
+ primekg_nodes = biobridge_primekg.get_primekg().get_nodes()
130
+ primekg_edges = biobridge_primekg.get_primekg().get_edges()
131
+ biobridge_data_config = biobridge_primekg.get_data_config()
132
+ biobridge_emb_dict = biobridge_primekg.get_node_embeddings()
133
+ biobridge_triplets = biobridge_primekg.get_primekg_triplets()
134
+ biobridge_splits = biobridge_primekg.get_train_test_split()
135
+ biobridge_node_info = biobridge_primekg.get_node_info_dict()
136
+
137
+ # Check if the local directories exists
138
+ assert os.path.exists(biobridge_primekg.primekg_dir)
139
+ assert os.path.exists(biobridge_primekg.local_dir)
140
+ # Check if downloaded and processed files exist
141
+ # PrimeKG files
142
+ files = ["nodes.tab", "primekg_nodes.tsv.gz",
143
+ "edges.csv", "primekg_edges.tsv.gz"]
144
+ for file in files:
145
+ path = f"{biobridge_primekg.primekg_dir}/{file}"
146
+ assert os.path.exists(path)
147
+ # BioBridge data config
148
+ assert os.path.exists(f"{biobridge_primekg.local_dir}/data_config.json")
149
+ # BioBridge embeddings
150
+ files = [
151
+ "protein.pkl",
152
+ "mf.pkl",
153
+ "cc.pkl",
154
+ "bp.pkl",
155
+ "drug.pkl",
156
+ "disease.pkl",
157
+ "embedding_dict.pkl"
158
+ ]
159
+ for file in files:
160
+ path = f"{biobridge_primekg.local_dir}/embeddings/{file}"
161
+ assert os.path.exists(path)
162
+ # BioBridge processed files
163
+ files = [
164
+ "protein.csv",
165
+ "mf.csv",
166
+ "cc.csv",
167
+ "bp.csv",
168
+ "drug.csv",
169
+ "disease.csv",
170
+ "triplet_full.tsv.gz",
171
+ "triplet_full_altered.tsv.gz",
172
+ "node_train.tsv.gz",
173
+ "triplet_train.tsv.gz",
174
+ "node_test.tsv.gz",
175
+ "triplet_test.tsv.gz",
176
+ ]
177
+ for file in files:
178
+ path = f"{biobridge_primekg.local_dir}/processed/{file}"
179
+ assert os.path.exists(path)
180
+ # Check processed PrimeKG dataframes
181
+ # Nodes
182
+ assert primekg_nodes is not None
183
+ assert len(primekg_nodes) > 0
184
+ assert primekg_nodes.shape[0] == 129375
185
+ # Edges
186
+ assert primekg_edges is not None
187
+ assert len(primekg_edges) > 0
188
+ assert primekg_edges.shape[0] == 8100498
189
+ # Check processed BioBridge data config
190
+ assert biobridge_data_config is not None
191
+ assert len(biobridge_data_config) > 0
192
+ assert len(biobridge_data_config['node_type']) == 10
193
+ assert len(biobridge_data_config['relation_type']) == 18
194
+ assert len(biobridge_data_config['emb_dim']) == 6
195
+ # Check processed BioBridge embeddings
196
+ assert biobridge_emb_dict is not None
197
+ assert len(biobridge_emb_dict) > 0
198
+ assert len(biobridge_emb_dict) == 85466
199
+ # Check processed BioBridge triplets
200
+ assert biobridge_triplets is not None
201
+ assert len(biobridge_triplets) > 0
202
+ assert biobridge_triplets.shape[0] == 3904610
203
+ assert list(biobridge_splits.keys()) == ['train', 'node_train', 'test', 'node_test']
204
+ assert len(biobridge_splits['train']) == 3510930
205
+ assert len(biobridge_splits['node_train']) == 76486
206
+ assert len(biobridge_splits['test']) == 393680
207
+ assert len(biobridge_splits['node_test']) == 8495
208
+ # Check node info dictionary
209
+ assert list(biobridge_node_info.keys()) == ['gene/protein',
210
+ 'molecular_function',
211
+ 'cellular_component',
212
+ 'biological_process',
213
+ 'drug',
214
+ 'disease']
215
+ assert len(biobridge_node_info['gene/protein']) == 19162
216
+ assert len(biobridge_node_info['molecular_function']) == 10966
217
+ assert len(biobridge_node_info['cellular_component']) == 4013
218
+ assert len(biobridge_node_info['biological_process']) == 27478
219
+ assert len(biobridge_node_info['drug']) == 6948
220
+ assert len(biobridge_node_info['disease']) == 44133
221
+
222
+ # def test_load_existing_primekg_with_negative_triplets(biobridge_primekg):
223
+ # """
224
+ # Test the loading method of the BioBridge-PrimeKG class by loading existing data in local.
225
+ # In addition, it builds negative triplets for training data.
226
+ # """
227
+ # # Load BioBridge-PrimeKG data
228
+ # # Using 1 negative sample per positive triplet
229
+ # biobridge_primekg.load_data(build_neg_triplest=True, n_neg_samples=1)
230
+ # biobridge_neg_triplets = biobridge_primekg.get_primekg_triplets_negative()
231
+
232
+ # # Check if the local directories exists
233
+ # assert os.path.exists(biobridge_primekg.primekg_dir)
234
+ # assert os.path.exists(biobridge_primekg.local_dir)
235
+ # # Check if downloaded and processed files exist
236
+ # path = f"{biobridge_primekg.local_dir}/processed/triplet_train_negative.tsv.gz"
237
+ # assert os.path.exists(path)
238
+ # # Check processed BioBridge triplets
239
+ # assert biobridge_neg_triplets is not None
240
+ # assert len(biobridge_neg_triplets) > 0
241
+ # assert biobridge_neg_triplets.shape[0] == 3510930
242
+ # assert len(biobridge_neg_triplets.negative_tail_index[0]) == 1
@@ -0,0 +1,29 @@
1
+ """
2
+ Test cases for datasets/dataset.py
3
+ """
4
+
5
+ from ..datasets.dataset import Dataset
6
+
7
+ class MockDataset(Dataset):
8
+ """
9
+ Mock dataset class for testing purposes.
10
+ """
11
+ def setup(self):
12
+ pass
13
+
14
+ def load_data(self):
15
+ pass
16
+
17
+ def test_dataset_setup():
18
+ """
19
+ Test the setup method of the Dataset class.
20
+ """
21
+ dataset = MockDataset()
22
+ assert dataset.setup() is None
23
+
24
+ def test_dataset_load_data():
25
+ """
26
+ Test the load_data method of the Dataset class.
27
+ """
28
+ dataset = MockDataset()
29
+ assert dataset.load_data() is None
@@ -0,0 +1,73 @@
1
+ """
2
+ Test cases for datasets/primekg_loader.py
3
+ """
4
+
5
+ import os
6
+ import shutil
7
+ import pytest
8
+ from ..datasets.primekg import PrimeKG
9
+
10
+ # Remove the data folder for testing if it exists
11
+ LOCAL_DIR = "../data/primekg_test/"
12
+ shutil.rmtree(LOCAL_DIR, ignore_errors=True)
13
+
14
+ @pytest.fixture(name="primekg")
15
+ def primekg_fixture():
16
+ """
17
+ Fixture for creating an instance of PrimeKG.
18
+ """
19
+ return PrimeKG(local_dir=LOCAL_DIR)
20
+
21
+ def test_download_primekg(primekg):
22
+ """
23
+ Test the loading method of the PrimeKG class by downloading PrimeKG from server.
24
+ """
25
+ # Load PrimeKG data
26
+ primekg.load_data()
27
+ primekg_nodes = primekg.get_nodes()
28
+ primekg_edges = primekg.get_edges()
29
+
30
+ # Check if the local directory exists
31
+ assert os.path.exists(primekg.local_dir)
32
+ # Check if downloaded and processed files exist
33
+ files = ["nodes.tab", f"{primekg.name}_nodes.tsv.gz",
34
+ "edges.csv", f"{primekg.name}_edges.tsv.gz"]
35
+ for file in files:
36
+ path = f"{primekg.local_dir}/{file}"
37
+ assert os.path.exists(path)
38
+ # Check processed PrimeKG dataframes
39
+ # Nodes
40
+ assert primekg_nodes is not None
41
+ assert len(primekg_nodes) > 0
42
+ assert primekg_nodes.shape[0] == 129375
43
+ # Edges
44
+ assert primekg_edges is not None
45
+ assert len(primekg_edges) > 0
46
+ assert primekg_edges.shape[0] == 8100498
47
+
48
+ def test_load_existing_primekg(primekg):
49
+ """
50
+ Test the loading method of the PrimeKG class by loading existing PrimeKG in local.
51
+ """
52
+ # Load PrimeKG data
53
+ primekg.load_data()
54
+ primekg_nodes = primekg.get_nodes()
55
+ primekg_edges = primekg.get_edges()
56
+
57
+ # Check if the local directory exists
58
+ assert os.path.exists(primekg.local_dir)
59
+ # Check if downloaded and processed files exist
60
+ files = ["nodes.tab", f"{primekg.name}_nodes.tsv.gz",
61
+ "edges.csv", f"{primekg.name}_edges.tsv.gz"]
62
+ for file in files:
63
+ path = f"{primekg.local_dir}/{file}"
64
+ assert os.path.exists(path)
65
+ # Check processed PrimeKG dataframes
66
+ # Nodes
67
+ assert primekg_nodes is not None
68
+ assert len(primekg_nodes) > 0
69
+ assert primekg_nodes.shape[0] == 129375
70
+ # Edges
71
+ assert primekg_edges is not None
72
+ assert len(primekg_edges) > 0
73
+ assert primekg_edges.shape[0] == 8100498
@@ -0,0 +1,116 @@
1
+ """
2
+ Test cases for datasets/starkqa_primekg_loader.py
3
+ """
4
+
5
+ import os
6
+ import shutil
7
+ import pytest
8
+ from ..datasets.starkqa_primekg import StarkQAPrimeKG
9
+
10
+ # Remove the data folder for testing if it exists
11
+ LOCAL_DIR = "../data/starkqa_primekg_test/"
12
+ shutil.rmtree(LOCAL_DIR, ignore_errors=True)
13
+
14
+ @pytest.fixture(name="starkqa_primekg")
15
+ def starkqa_primekg_fixture():
16
+ """
17
+ Fixture for creating an instance of StarkQAPrimeKGData.
18
+ """
19
+ return StarkQAPrimeKG(local_dir=LOCAL_DIR)
20
+
21
+ def test_download_starkqa_primekg(starkqa_primekg):
22
+ """
23
+ Test the loading method of the StarkQAPrimeKGLoaderTool class by downloading files
24
+ from HuggingFace Hub.
25
+ """
26
+ # Load StarkQA PrimeKG data
27
+ starkqa_primekg.load_data()
28
+ starkqa_df = starkqa_primekg.get_starkqa()
29
+ primekg_node_info = starkqa_primekg.get_starkqa_node_info()
30
+ split_idx = starkqa_primekg.get_starkqa_split_indicies()
31
+ query_embeddings = starkqa_primekg.get_query_embeddings()
32
+ node_embeddings = starkqa_primekg.get_node_embeddings()
33
+
34
+ # Check if the local directory exists
35
+ assert os.path.exists(starkqa_primekg.local_dir)
36
+ # Check if downloaded files exist in the local directory
37
+ files = ['qa/prime/split/test-0.1.index',
38
+ 'qa/prime/split/test.index',
39
+ 'qa/prime/split/train.index',
40
+ 'qa/prime/split/val.index',
41
+ 'qa/prime/stark_qa/stark_qa.csv',
42
+ 'qa/prime/stark_qa/stark_qa_human_generated_eval.csv',
43
+ 'skb/prime/processed.zip']
44
+ for file in files:
45
+ path = f"{starkqa_primekg.local_dir}/{file}"
46
+ assert os.path.exists(path)
47
+ # Check dataframe
48
+ assert starkqa_df is not None
49
+ assert len(starkqa_df) > 0
50
+ assert starkqa_df.shape[0] == 11204
51
+ # Check node information
52
+ assert primekg_node_info is not None
53
+ assert len(primekg_node_info) == 129375
54
+ # Check split indices
55
+ assert list(split_idx.keys()) == ['train', 'val', 'test', 'test-0.1']
56
+ assert len(split_idx['train']) == 6162
57
+ assert len(split_idx['val']) == 2241
58
+ assert len(split_idx['test']) == 2801
59
+ assert len(split_idx['test-0.1']) == 280
60
+ # Check query embeddings
61
+ assert query_embeddings is not None
62
+ assert len(query_embeddings) == 11204
63
+ assert query_embeddings[0].shape[1] == 1536
64
+ # Check node embeddings
65
+ assert node_embeddings is not None
66
+ assert len(node_embeddings) == 129375
67
+ assert node_embeddings[0].shape[1] == 1536
68
+
69
+ def test_load_existing_starkqa_primekg(starkqa_primekg):
70
+ """
71
+
72
+ Test the loading method of the StarkQAPrimeKGLoaderTool class by loading existing files
73
+ in the local directory.
74
+ """
75
+ # Load StarkQA PrimeKG data
76
+ starkqa_primekg.load_data()
77
+ starkqa_df = starkqa_primekg.get_starkqa()
78
+ primekg_node_info = starkqa_primekg.get_starkqa_node_info()
79
+ split_idx = starkqa_primekg.get_starkqa_split_indicies()
80
+ query_embeddings = starkqa_primekg.get_query_embeddings()
81
+ node_embeddings = starkqa_primekg.get_node_embeddings()
82
+
83
+ # Check if the local directory exists
84
+ assert os.path.exists(starkqa_primekg.local_dir)
85
+ # Check if downloaded and processed files exist
86
+ files = ['qa/prime/split/test-0.1.index',
87
+ 'qa/prime/split/test.index',
88
+ 'qa/prime/split/train.index',
89
+ 'qa/prime/split/val.index',
90
+ 'qa/prime/stark_qa/stark_qa.csv',
91
+ 'qa/prime/stark_qa/stark_qa_human_generated_eval.csv',
92
+ 'skb/prime/processed.zip']
93
+ for file in files:
94
+ path = f"{starkqa_primekg.local_dir}/{file}"
95
+ assert os.path.exists(path)
96
+ # Check dataframe
97
+ assert starkqa_df is not None
98
+ assert len(starkqa_df) > 0
99
+ assert starkqa_df.shape[0] == 11204
100
+ # Check node information
101
+ assert primekg_node_info is not None
102
+ assert len(primekg_node_info) == 129375
103
+ # Check split indices
104
+ assert list(split_idx.keys()) == ['train', 'val', 'test', 'test-0.1']
105
+ assert len(split_idx['train']) == 6162
106
+ assert len(split_idx['val']) == 2241
107
+ assert len(split_idx['test']) == 2801
108
+ assert len(split_idx['test-0.1']) == 280
109
+ # Check query embeddings
110
+ assert query_embeddings is not None
111
+ assert len(query_embeddings) == 11204
112
+ assert query_embeddings[0].shape[1] == 1536
113
+ # Check node embeddings
114
+ assert node_embeddings is not None
115
+ assert len(node_embeddings) == 129375
116
+ assert node_embeddings[0].shape[1] == 1536
@@ -0,0 +1,47 @@
1
+ """
2
+ Test cases for utils/embeddings/embeddings.py
3
+ """
4
+
5
+ import pytest
6
+ from ..utils.embeddings.embeddings import Embeddings
7
+
8
+ class TestEmbeddings(Embeddings):
9
+ """Test implementation of the Embeddings interface for testing purposes."""
10
+
11
+ def embed_documents(self, texts: list[str]) -> list[list[float]]:
12
+ return [[0.1, 0.2, 0.3] for _ in texts]
13
+
14
+ def embed_query(self, text: str) -> list[float]:
15
+ return [0.1, 0.2, 0.3]
16
+
17
+ def test_embed_documents():
18
+ """Test embedding documents using the Embeddings interface."""
19
+ embeddings = TestEmbeddings()
20
+ texts = ["text1", "text2"]
21
+ result = embeddings.embed_documents(texts)
22
+ assert result == [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]
23
+
24
+
25
+ def test_embed_query():
26
+ """Test embedding a query using the Embeddings interface."""
27
+ embeddings = TestEmbeddings()
28
+ text = "query"
29
+ result = embeddings.embed_query(text)
30
+ assert result == [0.1, 0.2, 0.3]
31
+
32
+ @pytest.mark.asyncio
33
+ async def test_aembed_documents():
34
+ """Test asynchronous embedding of documents using the Embeddings interface."""
35
+ embeddings = TestEmbeddings()
36
+ texts = ["text1", "text2"]
37
+ result = await embeddings.aembed_documents(texts)
38
+ assert result == [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]
39
+
40
+
41
+ @pytest.mark.asyncio
42
+ async def test_aembed_query():
43
+ """Test asynchronous embedding of a query using the Embeddings interface."""
44
+ embeddings = TestEmbeddings()
45
+ text = "query"
46
+ result = await embeddings.aembed_query(text)
47
+ assert result == [0.1, 0.2, 0.3]
@@ -0,0 +1,45 @@
1
+ """
2
+ Test cases for utils/embeddings/huggingface.py
3
+ """
4
+
5
+ import pytest
6
+ from ..utils.embeddings.huggingface import EmbeddingWithHuggingFace
7
+
8
+ @pytest.fixture(name="embedding_model")
9
+ def embedding_model_fixture():
10
+ """Return the configuration object for the HuggingFace embedding model and model object"""
11
+ return EmbeddingWithHuggingFace(
12
+ model_name="NeuML/pubmedbert-base-embeddings",
13
+ model_cache_dir="../../cache",
14
+ truncation=True,
15
+ )
16
+
17
+ def test_embedding_with_huggingface_embed_documents(embedding_model):
18
+ """Test embedding documents using the EmbeddingWithHuggingFace class."""
19
+ # Perform embedding
20
+ texts = ["Adalimumab", "Infliximab", "Vedolizumab"]
21
+ result = embedding_model.embed_documents(texts)
22
+ # Check the result
23
+ assert len(result) == 3
24
+ assert len(result[0]) == 768
25
+
26
+ def test_embedding_with_huggingface_embed_query(embedding_model):
27
+ """Test embedding a query using the EmbeddingWithHuggingFace class."""
28
+ # Perform embedding
29
+ text = "Adalimumab"
30
+ result = embedding_model.embed_query(text)
31
+ # Check the result
32
+ assert len(result) == 768
33
+
34
+
35
+ def test_embedding_with_huggingface_failed():
36
+ """Test embedding documents using the EmbeddingWithHuggingFace class."""
37
+ # Check if the model is available on HuggingFace Hub
38
+ model_name = "aiagents4pharma/embeddings"
39
+ err_msg = f"Model {model_name} is not available on HuggingFace Hub."
40
+ with pytest.raises(ValueError, match=err_msg):
41
+ EmbeddingWithHuggingFace(
42
+ model_name=model_name,
43
+ model_cache_dir="../../cache",
44
+ truncation=True,
45
+ )