aiagents4pharma 1.8.0__py3-none-any.whl → 1.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. aiagents4pharma/__init__.py +9 -5
  2. aiagents4pharma/configs/__init__.py +5 -0
  3. aiagents4pharma/configs/config.yaml +4 -0
  4. aiagents4pharma/configs/talk2biomodels/__init__.py +6 -0
  5. aiagents4pharma/configs/talk2biomodels/agents/__init__.py +5 -0
  6. aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/__init__.py +3 -0
  7. aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml +14 -0
  8. aiagents4pharma/configs/talk2biomodels/tools/__init__.py +4 -0
  9. aiagents4pharma/configs/talk2biomodels/tools/ask_question/__init__.py +3 -0
  10. aiagents4pharma/talk2biomodels/__init__.py +3 -0
  11. aiagents4pharma/talk2biomodels/agents/__init__.py +5 -0
  12. aiagents4pharma/talk2biomodels/agents/t2b_agent.py +96 -0
  13. aiagents4pharma/talk2biomodels/api/__init__.py +6 -0
  14. aiagents4pharma/talk2biomodels/api/kegg.py +83 -0
  15. aiagents4pharma/talk2biomodels/api/ols.py +72 -0
  16. aiagents4pharma/talk2biomodels/api/uniprot.py +35 -0
  17. aiagents4pharma/talk2biomodels/models/basico_model.py +29 -32
  18. aiagents4pharma/talk2biomodels/models/sys_bio_model.py +9 -6
  19. aiagents4pharma/talk2biomodels/states/__init__.py +5 -0
  20. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +41 -0
  21. aiagents4pharma/talk2biomodels/tests/__init__.py +3 -0
  22. aiagents4pharma/talk2biomodels/tests/test_api.py +57 -0
  23. aiagents4pharma/talk2biomodels/tests/test_ask_question.py +44 -0
  24. aiagents4pharma/talk2biomodels/tests/test_basico_model.py +54 -0
  25. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +171 -0
  26. aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +26 -0
  27. aiagents4pharma/talk2biomodels/tests/test_integration.py +126 -0
  28. aiagents4pharma/talk2biomodels/tests/test_param_scan.py +68 -0
  29. aiagents4pharma/talk2biomodels/tests/test_query_article.py +76 -0
  30. aiagents4pharma/talk2biomodels/tests/test_search_models.py +28 -0
  31. aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +39 -0
  32. aiagents4pharma/talk2biomodels/tests/test_steady_state.py +90 -0
  33. aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +63 -0
  34. aiagents4pharma/talk2biomodels/tools/__init__.py +5 -0
  35. aiagents4pharma/talk2biomodels/tools/ask_question.py +61 -18
  36. aiagents4pharma/talk2biomodels/tools/custom_plotter.py +20 -14
  37. aiagents4pharma/talk2biomodels/tools/get_annotation.py +304 -0
  38. aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +11 -9
  39. aiagents4pharma/talk2biomodels/tools/load_arguments.py +114 -0
  40. aiagents4pharma/talk2biomodels/tools/load_biomodel.py +0 -1
  41. aiagents4pharma/talk2biomodels/tools/parameter_scan.py +287 -0
  42. aiagents4pharma/talk2biomodels/tools/query_article.py +59 -0
  43. aiagents4pharma/talk2biomodels/tools/simulate_model.py +35 -90
  44. aiagents4pharma/talk2biomodels/tools/steady_state.py +167 -0
  45. aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +23 -0
  46. aiagents4pharma/talk2cells/tools/scp_agent/__init__.py +6 -0
  47. aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py +25 -0
  48. aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py +79 -0
  49. aiagents4pharma/talk2competitors/__init__.py +5 -0
  50. aiagents4pharma/talk2competitors/agents/__init__.py +6 -0
  51. aiagents4pharma/talk2competitors/agents/main_agent.py +130 -0
  52. aiagents4pharma/talk2competitors/agents/s2_agent.py +75 -0
  53. aiagents4pharma/talk2competitors/config/__init__.py +5 -0
  54. aiagents4pharma/talk2competitors/config/config.py +110 -0
  55. aiagents4pharma/talk2competitors/state/__init__.py +5 -0
  56. aiagents4pharma/talk2competitors/state/state_talk2competitors.py +32 -0
  57. aiagents4pharma/talk2competitors/tests/__init__.py +3 -0
  58. aiagents4pharma/talk2competitors/tests/test_langgraph.py +274 -0
  59. aiagents4pharma/talk2competitors/tools/__init__.py +7 -0
  60. aiagents4pharma/talk2competitors/tools/s2/__init__.py +8 -0
  61. aiagents4pharma/talk2competitors/tools/s2/display_results.py +25 -0
  62. aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py +132 -0
  63. aiagents4pharma/talk2competitors/tools/s2/search.py +119 -0
  64. aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py +141 -0
  65. aiagents4pharma/talk2knowledgegraphs/__init__.py +2 -1
  66. aiagents4pharma/talk2knowledgegraphs/tests/__init__.py +0 -0
  67. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +242 -0
  68. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +29 -0
  69. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +73 -0
  70. aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +116 -0
  71. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +47 -0
  72. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +45 -0
  73. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py +40 -0
  74. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py +39 -0
  75. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +117 -0
  76. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +5 -0
  77. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +5 -0
  78. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py +36 -0
  79. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py +123 -0
  80. {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/METADATA +44 -25
  81. aiagents4pharma-1.15.0.dist-info/RECORD +102 -0
  82. aiagents4pharma-1.8.0.dist-info/RECORD +0 -35
  83. {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/LICENSE +0 -0
  84. {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/WHEEL +0 -0
  85. {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,141 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ This tool is used to return recommendations for a single paper.
5
+ """
6
+
7
+ import logging
8
+ from typing import Annotated, Any, Dict, Optional
9
+
10
+ import pandas as pd
11
+ import requests
12
+ from langchain_core.messages import ToolMessage
13
+ from langchain_core.tools import tool
14
+ from langchain_core.tools.base import InjectedToolCallId
15
+ from langgraph.types import Command
16
+ from pydantic import BaseModel, Field
17
+
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class SinglePaperRecInput(BaseModel):
24
+ """Input schema for single paper recommendation tool."""
25
+
26
+ paper_id: str = Field(
27
+ description="Semantic Scholar Paper ID to get recommendations for (40-character string)"
28
+ )
29
+ limit: int = Field(
30
+ default=2,
31
+ description="Maximum number of recommendations to return",
32
+ ge=1,
33
+ le=500,
34
+ )
35
+ year: Optional[str] = Field(
36
+ default=None,
37
+ description="Year range in format: YYYY for specific year, "
38
+ "YYYY- for papers after year, -YYYY for papers before year, or YYYY:YYYY for range",
39
+ )
40
+ tool_call_id: Annotated[str, InjectedToolCallId]
41
+ model_config = {"arbitrary_types_allowed": True}
42
+
43
+
44
@tool(args_schema=SinglePaperRecInput)
def get_single_paper_recommendations(
    paper_id: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    limit: int = 2,
    year: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Get paper recommendations based on a single paper.

    Args:
        paper_id (str): The Semantic Scholar Paper ID to get recommendations for.
        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
        limit (int, optional): The maximum number of recommendations to return. Defaults to 2.
        year (str, optional): Year range for papers.
            Supports formats like "2024-", "-2024", "2024:2025". Defaults to None.

    Returns:
        Dict[str, Any]: The recommendations and related information.
    """
    logger.info("Starting single paper recommendations search.")

    endpoint = (
        f"https://api.semanticscholar.org/recommendations/v1/papers/forpaper/{paper_id}"
    )
    params = {
        "limit": min(limit, 500),  # Max 500 per API docs
        "fields": "paperId,title,abstract,year,authors,citationCount,url",
        "from": "all-cs",  # Using all-cs pool as specified in docs
    }

    # Add year parameter if provided
    if year:
        params["year"] = year

    # Issue the request exactly once (a previous revision called the API
    # twice and discarded the first response). Timeout guards against a
    # hung connection.
    response = requests.get(endpoint, params=params, timeout=10)
    logger.info(
        "API Response Status for recommendations of paper %s: %s",
        paper_id,
        response.status_code,
    )
    logger.info("Request params: %s", params)

    data = response.json()
    recommendations = data.get("recommendedPapers", [])

    # Keep only recommendations that have both a title and authors;
    # map paper ID -> display fields.
    filtered_papers = {
        paper["paperId"]: {
            "Title": paper.get("title", "N/A"),
            "Abstract": paper.get("abstract", "N/A"),
            "Year": paper.get("year", "N/A"),
            "Citation Count": paper.get("citationCount", "N/A"),
            "URL": paper.get("url", "N/A"),
        }
        for paper in recommendations
        if paper.get("title") and paper.get("authors")
    }

    # Create a DataFrame for pretty printing
    df = pd.DataFrame(filtered_papers)

    # Human-readable summaries, used only for logging.
    papers = [
        f"Paper ID: {paper_id}\n"
        f"Title: {paper_data['Title']}\n"
        f"Abstract: {paper_data['Abstract']}\n"
        f"Year: {paper_data['Year']}\n"
        f"Citations: {paper_data['Citation Count']}\n"
        f"URL: {paper_data['URL']}\n"
        for paper_id, paper_data in filtered_papers.items()
    ]

    # Convert DataFrame to markdown table (requires the optional
    # `tabulate` dependency of pandas).
    markdown_table = df.to_markdown(tablefmt="grid")
    logger.info("Search results: %s", papers)

    # NOTE(review): the annotation says Dict[str, Any] but a langgraph
    # Command is returned — kept as-is since callers depend on Command.
    return Command(
        update={
            "papers": filtered_papers,  # Now sending the dictionary directly
            "messages": [
                ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
            ],
        }
    )
@@ -1,4 +1,5 @@
1
1
  '''
2
- This file is used to import the datasets, utils, and tools.
2
+ This file is used to import the datasets and utils.
3
3
  '''
4
4
  from . import datasets
5
+ from . import utils
File without changes
@@ -0,0 +1,242 @@
1
+ """
2
+ Test cases for datasets/primekg_loader.py
3
+ """
4
+
5
+ import os
6
+ import shutil
7
+ import pytest
8
+ from ..datasets.biobridge_primekg import BioBridgePrimeKG
9
+
10
+ # Remove the data folder for testing if it exists
11
+ PRIMEKG_LOCAL_DIR = "../data/primekg_test/"
12
+ LOCAL_DIR = "../data/biobridge_primekg_test/"
13
+ shutil.rmtree(LOCAL_DIR, ignore_errors=True)
14
+
15
@pytest.fixture(name="biobridge_primekg")
def biobridge_primekg_fixture():
    """
    Provide a BioBridgePrimeKG instance rooted at the test directories.
    """
    return BioBridgePrimeKG(
        primekg_dir=PRIMEKG_LOCAL_DIR,
        local_dir=LOCAL_DIR,
    )
22
+
23
def _assert_biobridge_loaded(biobridge_primekg):
    """Shared assertions validating a fully loaded BioBridge-PrimeKG instance.

    Both the download and load-existing tests exercise the exact same
    postconditions, so they are factored out here once.
    """
    primekg_nodes = biobridge_primekg.get_primekg().get_nodes()
    primekg_edges = biobridge_primekg.get_primekg().get_edges()
    biobridge_data_config = biobridge_primekg.get_data_config()
    biobridge_emb_dict = biobridge_primekg.get_node_embeddings()
    biobridge_triplets = biobridge_primekg.get_primekg_triplets()
    biobridge_splits = biobridge_primekg.get_train_test_split()
    biobridge_node_info = biobridge_primekg.get_node_info_dict()

    # Check if the local directories exist
    assert os.path.exists(biobridge_primekg.primekg_dir)
    assert os.path.exists(biobridge_primekg.local_dir)
    # PrimeKG files (raw downloads and processed tables)
    for file in ["nodes.tab", "primekg_nodes.tsv.gz",
                 "edges.csv", "primekg_edges.tsv.gz"]:
        assert os.path.exists(f"{biobridge_primekg.primekg_dir}/{file}")
    # BioBridge data config
    assert os.path.exists(f"{biobridge_primekg.local_dir}/data_config.json")
    # BioBridge embeddings
    for file in ["protein.pkl", "mf.pkl", "cc.pkl", "bp.pkl",
                 "drug.pkl", "disease.pkl", "embedding_dict.pkl"]:
        assert os.path.exists(f"{biobridge_primekg.local_dir}/embeddings/{file}")
    # BioBridge processed files
    for file in ["protein.csv", "mf.csv", "cc.csv", "bp.csv", "drug.csv",
                 "disease.csv", "triplet_full.tsv.gz",
                 "triplet_full_altered.tsv.gz", "node_train.tsv.gz",
                 "triplet_train.tsv.gz", "node_test.tsv.gz",
                 "triplet_test.tsv.gz"]:
        assert os.path.exists(f"{biobridge_primekg.local_dir}/processed/{file}")
    # Processed PrimeKG dataframes: nodes
    assert primekg_nodes is not None
    assert len(primekg_nodes) > 0
    assert primekg_nodes.shape[0] == 129375
    # Edges
    assert primekg_edges is not None
    assert len(primekg_edges) > 0
    assert primekg_edges.shape[0] == 8100498
    # BioBridge data config contents
    assert biobridge_data_config is not None
    assert len(biobridge_data_config) > 0
    assert len(biobridge_data_config['node_type']) == 10
    assert len(biobridge_data_config['relation_type']) == 18
    assert len(biobridge_data_config['emb_dim']) == 6
    # BioBridge embeddings
    assert biobridge_emb_dict is not None
    assert len(biobridge_emb_dict) > 0
    assert len(biobridge_emb_dict) == 85466
    # BioBridge triplets
    assert biobridge_triplets is not None
    assert len(biobridge_triplets) > 0
    assert biobridge_triplets.shape[0] == 3904610
    # Train/test split sizes
    assert list(biobridge_splits.keys()) == ['train', 'node_train', 'test', 'node_test']
    assert len(biobridge_splits['train']) == 3510930
    assert len(biobridge_splits['node_train']) == 76486
    assert len(biobridge_splits['test']) == 393680
    assert len(biobridge_splits['node_test']) == 8495
    # Node info dictionary
    assert list(biobridge_node_info.keys()) == ['gene/protein',
                                                'molecular_function',
                                                'cellular_component',
                                                'biological_process',
                                                'drug',
                                                'disease']
    assert len(biobridge_node_info['gene/protein']) == 19162
    assert len(biobridge_node_info['molecular_function']) == 10966
    assert len(biobridge_node_info['cellular_component']) == 4013
    assert len(biobridge_node_info['biological_process']) == 27478
    assert len(biobridge_node_info['drug']) == 6948
    assert len(biobridge_node_info['disease']) == 44133


def test_download_primekg(biobridge_primekg):
    """
    Test the loading method of the BioBridge-PrimeKG class by downloading data from repository.
    """
    # Load BioBridge-PrimeKG data (first run: downloads + processes)
    biobridge_primekg.load_data()
    _assert_biobridge_loaded(biobridge_primekg)


def test_load_existing_primekg(biobridge_primekg):
    """
    Test the loading method of the BioBridge-PrimeKG class by loading existing data in local.
    """
    # Load BioBridge-PrimeKG data (second run: reads the files the
    # previous test left behind)
    biobridge_primekg.load_data()
    _assert_biobridge_loaded(biobridge_primekg)
221
+
222
+ # def test_load_existing_primekg_with_negative_triplets(biobridge_primekg):
223
+ # """
224
+ # Test the loading method of the BioBridge-PrimeKG class by loading existing data in local.
225
+ # In addition, it builds negative triplets for training data.
226
+ # """
227
+ # # Load BioBridge-PrimeKG data
228
+ # # Using 1 negative sample per positive triplet
229
+ # biobridge_primekg.load_data(build_neg_triplest=True, n_neg_samples=1)
230
+ # biobridge_neg_triplets = biobridge_primekg.get_primekg_triplets_negative()
231
+
232
+ # # Check if the local directories exists
233
+ # assert os.path.exists(biobridge_primekg.primekg_dir)
234
+ # assert os.path.exists(biobridge_primekg.local_dir)
235
+ # # Check if downloaded and processed files exist
236
+ # path = f"{biobridge_primekg.local_dir}/processed/triplet_train_negative.tsv.gz"
237
+ # assert os.path.exists(path)
238
+ # # Check processed BioBridge triplets
239
+ # assert biobridge_neg_triplets is not None
240
+ # assert len(biobridge_neg_triplets) > 0
241
+ # assert biobridge_neg_triplets.shape[0] == 3510930
242
+ # assert len(biobridge_neg_triplets.negative_tail_index[0]) == 1
@@ -0,0 +1,29 @@
1
+ """
2
+ Test cases for datasets/dataset.py
3
+ """
4
+
5
+ from ..datasets.dataset import Dataset
6
+
7
class MockDataset(Dataset):
    """
    Minimal concrete Dataset used only to exercise the abstract interface.
    """

    def setup(self):
        """No-op setup; implicitly returns None."""
        return None

    def load_data(self):
        """No-op data loading; implicitly returns None."""
        return None
16
+
17
def test_dataset_setup():
    """
    Test the setup method of the Dataset class.
    """
    assert MockDataset().setup() is None


def test_dataset_load_data():
    """
    Test the load_data method of the Dataset class.
    """
    assert MockDataset().load_data() is None
@@ -0,0 +1,73 @@
1
+ """
2
+ Test cases for datasets/primekg_loader.py
3
+ """
4
+
5
+ import os
6
+ import shutil
7
+ import pytest
8
+ from ..datasets.primekg import PrimeKG
9
+
10
+ # Remove the data folder for testing if it exists
11
+ LOCAL_DIR = "../data/primekg_test/"
12
+ shutil.rmtree(LOCAL_DIR, ignore_errors=True)
13
+
14
@pytest.fixture(name="primekg")
def primekg_fixture():
    """
    Provide a PrimeKG instance rooted at the test directory.
    """
    return PrimeKG(local_dir=LOCAL_DIR)
20
+
21
def _assert_primekg_loaded(primekg):
    """Shared assertions validating a fully loaded PrimeKG instance.

    The download and load-existing tests check identical postconditions,
    so they are factored out here once.
    """
    primekg_nodes = primekg.get_nodes()
    primekg_edges = primekg.get_edges()

    # Check if the local directory exists
    assert os.path.exists(primekg.local_dir)
    # Check if downloaded and processed files exist
    for file in ["nodes.tab", f"{primekg.name}_nodes.tsv.gz",
                 "edges.csv", f"{primekg.name}_edges.tsv.gz"]:
        assert os.path.exists(f"{primekg.local_dir}/{file}")
    # Processed PrimeKG dataframes: nodes
    assert primekg_nodes is not None
    assert len(primekg_nodes) > 0
    assert primekg_nodes.shape[0] == 129375
    # Edges
    assert primekg_edges is not None
    assert len(primekg_edges) > 0
    assert primekg_edges.shape[0] == 8100498


def test_download_primekg(primekg):
    """
    Test the loading method of the PrimeKG class by downloading PrimeKG from server.
    """
    # First run: downloads and processes the data
    primekg.load_data()
    _assert_primekg_loaded(primekg)


def test_load_existing_primekg(primekg):
    """
    Test the loading method of the PrimeKG class by loading existing PrimeKG in local.
    """
    # Second run: reads the files left behind by the previous test
    primekg.load_data()
    _assert_primekg_loaded(primekg)
@@ -0,0 +1,116 @@
1
+ """
2
+ Test cases for datasets/starkqa_primekg_loader.py
3
+ """
4
+
5
+ import os
6
+ import shutil
7
+ import pytest
8
+ from ..datasets.starkqa_primekg import StarkQAPrimeKG
9
+
10
+ # Remove the data folder for testing if it exists
11
+ LOCAL_DIR = "../data/starkqa_primekg_test/"
12
+ shutil.rmtree(LOCAL_DIR, ignore_errors=True)
13
+
14
@pytest.fixture(name="starkqa_primekg")
def starkqa_primekg_fixture():
    """
    Provide a StarkQAPrimeKG instance rooted at the test directory.
    """
    return StarkQAPrimeKG(local_dir=LOCAL_DIR)
20
+
21
def _assert_starkqa_loaded(starkqa_primekg):
    """Shared assertions validating a fully loaded StarkQA-PrimeKG instance.

    The download and load-existing tests check identical postconditions,
    so they are factored out here once.
    """
    starkqa_df = starkqa_primekg.get_starkqa()
    primekg_node_info = starkqa_primekg.get_starkqa_node_info()
    # NOTE(review): "indicies" is the spelling of the project API method.
    split_idx = starkqa_primekg.get_starkqa_split_indicies()
    query_embeddings = starkqa_primekg.get_query_embeddings()
    node_embeddings = starkqa_primekg.get_node_embeddings()

    # Check if the local directory exists
    assert os.path.exists(starkqa_primekg.local_dir)
    # Check if downloaded files exist in the local directory
    files = ['qa/prime/split/test-0.1.index',
             'qa/prime/split/test.index',
             'qa/prime/split/train.index',
             'qa/prime/split/val.index',
             'qa/prime/stark_qa/stark_qa.csv',
             'qa/prime/stark_qa/stark_qa_human_generated_eval.csv',
             'skb/prime/processed.zip']
    for file in files:
        assert os.path.exists(f"{starkqa_primekg.local_dir}/{file}")
    # Check dataframe
    assert starkqa_df is not None
    assert len(starkqa_df) > 0
    assert starkqa_df.shape[0] == 11204
    # Check node information
    assert primekg_node_info is not None
    assert len(primekg_node_info) == 129375
    # Check split indices
    assert list(split_idx.keys()) == ['train', 'val', 'test', 'test-0.1']
    assert len(split_idx['train']) == 6162
    assert len(split_idx['val']) == 2241
    assert len(split_idx['test']) == 2801
    assert len(split_idx['test-0.1']) == 280
    # Check query embeddings
    assert query_embeddings is not None
    assert len(query_embeddings) == 11204
    assert query_embeddings[0].shape[1] == 1536
    # Check node embeddings
    assert node_embeddings is not None
    assert len(node_embeddings) == 129375
    assert node_embeddings[0].shape[1] == 1536


def test_download_starkqa_primekg(starkqa_primekg):
    """
    Test the loading method of the StarkQAPrimeKGLoaderTool class by downloading files
    from HuggingFace Hub.
    """
    # First run: downloads the files
    starkqa_primekg.load_data()
    _assert_starkqa_loaded(starkqa_primekg)


def test_load_existing_starkqa_primekg(starkqa_primekg):
    """
    Test the loading method of the StarkQAPrimeKGLoaderTool class by loading existing files
    in the local directory.
    """
    # Second run: reads the files left behind by the previous test
    starkqa_primekg.load_data()
    _assert_starkqa_loaded(starkqa_primekg)
@@ -0,0 +1,47 @@
1
+ """
2
+ Test cases for utils/embeddings/embeddings.py
3
+ """
4
+
5
+ import pytest
6
+ from ..utils.embeddings.embeddings import Embeddings
7
+
8
class TestEmbeddings(Embeddings):
    """Test implementation of the Embeddings interface for testing purposes."""

    # The Test* name matches pytest's collection pattern; opt out so pytest
    # does not try to collect this helper class as a test case.
    __test__ = False

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Return a fixed 3-dim vector for every input text."""
        return [[0.1, 0.2, 0.3] for _ in texts]

    def embed_query(self, text: str) -> list[float]:
        """Return a fixed 3-dim vector regardless of the query."""
        return [0.1, 0.2, 0.3]
16
+
17
def test_embed_documents():
    """Test embedding documents using the Embeddings interface."""
    embedder = TestEmbeddings()
    docs = ["text1", "text2"]
    assert embedder.embed_documents(docs) == [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]


def test_embed_query():
    """Test embedding a query using the Embeddings interface."""
    embedder = TestEmbeddings()
    assert embedder.embed_query("query") == [0.1, 0.2, 0.3]
31
+
32
@pytest.mark.asyncio
async def test_aembed_documents():
    """Test asynchronous embedding of documents using the Embeddings interface."""
    embedder = TestEmbeddings()
    docs = ["text1", "text2"]
    assert await embedder.aembed_documents(docs) == [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]


@pytest.mark.asyncio
async def test_aembed_query():
    """Test asynchronous embedding of a query using the Embeddings interface."""
    embedder = TestEmbeddings()
    assert await embedder.aembed_query("query") == [0.1, 0.2, 0.3]