aiagents4pharma-1.17.1-py3-none-any.whl → aiagents4pharma-1.19.0-py3-none-any.whl

Files changed (56)
  1. aiagents4pharma/talk2biomodels/agents/t2b_agent.py +4 -4
  2. aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +7 -15
  3. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +4 -1
  4. aiagents4pharma/talk2biomodels/tests/test_ask_question.py +4 -2
  5. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +4 -2
  6. aiagents4pharma/talk2biomodels/tests/test_integration.py +34 -30
  7. aiagents4pharma/talk2biomodels/tests/test_query_article.py +7 -1
  8. aiagents4pharma/talk2biomodels/tests/test_search_models.py +3 -1
  9. aiagents4pharma/talk2biomodels/tests/test_steady_state.py +6 -3
  10. aiagents4pharma/talk2biomodels/tools/ask_question.py +1 -2
  11. aiagents4pharma/talk2biomodels/tools/custom_plotter.py +23 -10
  12. aiagents4pharma/talk2biomodels/tools/get_annotation.py +11 -10
  13. aiagents4pharma/talk2biomodels/tools/query_article.py +6 -2
  14. aiagents4pharma/talk2biomodels/tools/search_models.py +8 -2
  15. aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
  16. aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
  17. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
  18. aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
  19. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
  20. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
  21. aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
  22. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
  23. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
  24. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
  25. aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
  26. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
  27. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
  28. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
  29. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
  30. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
  31. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
  32. aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
  33. aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
  34. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
  35. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
  36. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
  37. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
  38. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
  39. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
  40. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
  41. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
  42. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
  43. aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
  44. aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
  45. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
  46. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
  47. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
  48. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
  49. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
  50. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
  51. aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
  52. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +12 -3
  53. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +56 -24
  54. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
  55. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
  56. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,174 @@
+ """
+ Test cases for tools/subgraph_extraction.py
+ """
+
+ import pytest
+ from langchain_core.messages import HumanMessage
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+ from ..agents.t2kg_agent import get_app
+
+ # Define the data path
+ DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
+
+
+ @pytest.fixture(name="input_dict")
+ def input_dict_fixture():
+     """
+     Input dictionary fixture.
+     """
+     input_dict = {
+         "llm_model": None, # TBA for each test case
+         "embedding_model": None, # TBA for each test case
+         "uploaded_files": [],
+         "topk_nodes": 3,
+         "topk_edges": 3,
+         "dic_source_graph": [
+             {
+                 "name": "PrimeKG",
+                 "kg_pyg_path": f"{DATA_PATH}/primekg_ibd_pyg_graph.pkl",
+                 "kg_text_path": f"{DATA_PATH}/primekg_ibd_text_graph.pkl",
+             }
+         ],
+     }
+
+     return input_dict
+
+
+ def test_extract_subgraph_wo_docs(input_dict):
+     """
+     Test the subgraph extraction tool without any documents using OpenAI model.
+
+     Args:
+         input_dict: Input dictionary.
+     """
+     # Prepare LLM and embedding model
+     input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
+     input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
+
+     # Setup the app
+     unique_id = 12345
+     app = get_app(unique_id, llm_model=input_dict["llm_model"])
+     config = {"configurable": {"thread_id": unique_id}}
+     # Update state
+     app.update_state(
+         config,
+         input_dict,
+     )
+     prompt = """
+     Please directly invoke `subgraph_extraction` tool without calling any other tools
+     to respond to the following prompt:
+
+     Extract all relevant information related to nodes of genes related to inflammatory bowel disease
+     (IBD) that existed in the knowledge graph.
+     Please set the extraction name for this process as `subkg_12345`.
+     """
+
+     # Test the tool subgraph_extraction
+     response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
+
+     # Check assistant message
+     assistant_msg = response["messages"][-1].content
+     assert isinstance(assistant_msg, str)
+
+     # Check tool message
+     tool_msg = response["messages"][-2]
+     assert tool_msg.name == "subgraph_extraction"
+
+     # Check extracted subgraph dictionary
+     current_state = app.get_state(config)
+     dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
+     assert isinstance(dic_extracted_graph, dict)
+     assert dic_extracted_graph["name"] == "subkg_12345"
+     assert dic_extracted_graph["graph_source"] == "PrimeKG"
+     assert dic_extracted_graph["topk_nodes"] == 3
+     assert dic_extracted_graph["topk_edges"] == 3
+     assert isinstance(dic_extracted_graph["graph_dict"], dict)
+     assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
+     assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
+     assert isinstance(dic_extracted_graph["graph_text"], str)
+     # Check if the nodes are in the graph_text
+     assert all(
+         n[0] in dic_extracted_graph["graph_text"]
+         for n in dic_extracted_graph["graph_dict"]["nodes"]
+     )
+     # Check if the edges are in the graph_text
+     assert all(
+         ",".join([e[0], '"' + str(tuple(e[2]["relation"])) + '"', e[1]])
+         in dic_extracted_graph["graph_text"]
+         for e in dic_extracted_graph["graph_dict"]["edges"]
+     )
+
+
+ def test_extract_subgraph_w_docs(input_dict):
+     """
+     Test the subgraph extraction tool with a document as reference (i.e., endotype document)
+     using OpenAI model.
+
+     Args:
+         input_dict: Input dictionary.
+     """
+     # Prepare LLM and embedding model
+     input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
+     input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
+
+     # Setup the app
+     unique_id = 12345
+     app = get_app(unique_id, llm_model=input_dict["llm_model"])
+     config = {"configurable": {"thread_id": unique_id}}
+     # Update state
+     input_dict["uploaded_files"] = [
+         {
+             "file_name": "DGE_human_Colon_UC-vs-Colon_Control.pdf",
+             "file_path": f"{DATA_PATH}/DGE_human_Colon_UC-vs-Colon_Control.pdf",
+             "file_type": "endotype",
+             "uploaded_by": "VPEUser",
+             "uploaded_timestamp": "2024-11-05 00:00:00",
+         }
+     ]
+     app.update_state(
+         config,
+         input_dict,
+     )
+     prompt = """
+     Please ONLY invoke `subgraph_extraction` tool without calling any other tools
+     to respond to the following prompt:
+
+     Extract all relevant information related to nodes of genes related to inflammatory bowel disease
+     (IBD) that existed in the knowledge graph.
+     Please set the extraction name for this process as `subkg_12345`.
+     """
+
+     # Test the tool subgraph_extraction
+     response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
+
+     # Check assistant message
+     assistant_msg = response["messages"][-1].content
+     assert isinstance(assistant_msg, str)
+
+     # Check tool message
+     tool_msg = response["messages"][-2]
+     assert tool_msg.name == "subgraph_extraction"
+
+     # Check extracted subgraph dictionary
+     current_state = app.get_state(config)
+     dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
+     assert isinstance(dic_extracted_graph, dict)
+     assert dic_extracted_graph["name"] == "subkg_12345"
+     assert dic_extracted_graph["graph_source"] == "PrimeKG"
+     assert dic_extracted_graph["topk_nodes"] == 3
+     assert dic_extracted_graph["topk_edges"] == 3
+     assert isinstance(dic_extracted_graph["graph_dict"], dict)
+     assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
+     assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
+     assert isinstance(dic_extracted_graph["graph_text"], str)
+     # Check if the nodes are in the graph_text
+     assert all(
+         n[0] in dic_extracted_graph["graph_text"]
+         for n in dic_extracted_graph["graph_dict"]["nodes"]
+     )
+     # Check if the edges are in the graph_text
+     assert all(
+         ",".join([e[0], '"' + str(tuple(e[2]["relation"])) + '"', e[1]])
+         in dic_extracted_graph["graph_text"]
+         for e in dic_extracted_graph["graph_dict"]["edges"]
+     )
@@ -0,0 +1,154 @@
+ """
+ Test cases for tools/subgraph_summarization.py
+ """
+
+ import pytest
+ from langchain_core.messages import HumanMessage
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+ from ..agents.t2kg_agent import get_app
+
+ # Define the data path
+ DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
+
+
+ @pytest.fixture(name="input_dict")
+ def input_dict_fixture():
+     """
+     Input dictionary fixture.
+     """
+     input_dict = {
+         "llm_model": None, # TBA for each test case
+         "embedding_model": None, # TBA for each test case
+         "uploaded_files": [],
+         "topk_nodes": 3,
+         "topk_edges": 3,
+         "dic_source_graph": [
+             {
+                 "name": "PrimeKG",
+                 "kg_pyg_path": f"{DATA_PATH}/primekg_ibd_pyg_graph.pkl",
+                 "kg_text_path": f"{DATA_PATH}/primekg_ibd_text_graph.pkl",
+             }
+         ],
+         "dic_extracted_graph": [
+             {
+                 "name": "subkg_12345",
+                 "tool_call_id": "tool_12345",
+                 "graph_source": "PrimeKG",
+                 "topk_nodes": 3,
+                 "topk_edges": 3,
+                 "graph_dict": {
+                     'nodes': [('IFNG_(3495)', {}),
+                               ('IKBKG_(3672)', {}),
+                               ('ATG16L1_(6661)', {}),
+                               ('inflammatory bowel disease_(28158)', {}),
+                               ('Crohn ileitis and jejunitis_(35814)', {}),
+                               ("Crohn's colitis_(83770)", {})],
+                     'edges': [('IFNG_(3495)', 'inflammatory bowel disease_(28158)',
+                                {'relation': ['gene/protein', 'associated with', 'disease'],
+                                 'label': ['gene/protein', 'associated with', 'disease']}),
+                               ('IFNG_(3495)', "Crohn's colitis_(83770)",
+                                {'relation': ['gene/protein', 'associated with', 'disease'],
+                                 'label': ['gene/protein', 'associated with', 'disease']}),
+                               ('IFNG_(3495)', 'Crohn ileitis and jejunitis_(35814)',
+                                {'relation': ['gene/protein', 'associated with', 'disease'],
+                                 'label': ['gene/protein', 'associated with', 'disease']}),
+                               ('ATG16L1_(6661)', 'IKBKG_(3672)',
+                                {'relation': ['gene/protein', 'ppi', 'gene/protein'],
+                                 'label': ['gene/protein', 'ppi', 'gene/protein']}),
+                               ("Crohn's colitis_(83770)", 'ATG16L1_(6661)',
+                                {'relation': ['disease', 'associated with', 'gene/protein'],
+                                 'label': ['disease', 'associated with', 'gene/protein']})]},
+                 "graph_text": """
+ node_id,node_attr
+ IFNG_(3495),"IFNG belongs to gene/protein category.
+ This gene encodes a soluble cytokine that is a member of the type II interferon class.
+ The encoded protein is secreted by cells of both the innate and adaptive immune systems.
+ The active protein is a homodimer that binds to the interferon gamma receptor
+ which triggers a cellular response to viral and microbial infections.
+ Mutations in this gene are associated with an increased susceptibility to viral,
+ bacterial and parasitic infections and to several autoimmune diseases.
+ [provided by RefSeq, Dec 2015]."
+ IKBKG_(3672),"IKBKG belongs to gene/protein category. This gene encodes the regulatory
+ subunit of the inhibitor of kappaB kinase (IKK) complex, which activates NF-kappaB
+ resulting in activation of genes involved in inflammation, immunity, cell survival,
+ and other pathways. Mutations in this gene result in incontinentia pigmenti,
+ hypohidrotic ectodermal dysplasia, and several other types of immunodeficiencies.
+ A pseudogene highly similar to this locus is located in an adjacent region of the
+ X chromosome. [provided by RefSeq, Mar 2016]."
+ ATG16L1_(6661),"ATG16L1 belongs to gene/protein category. The protein encoded
+ by this gene is part of a large protein complex that is necessary for autophagy,
+ the major process by which intracellular components are targeted to lysosomes
+ for degradation. Defects in this gene are a cause of susceptibility to inflammatory
+ bowel disease type 10 (IBD10). Several transcript variants encoding different
+ isoforms have been found for this gene.[provided by RefSeq, Jun 2010]."
+ inflammatory bowel disease_(28158),inflammatory bowel disease belongs to disease
+ category. Any inflammatory bowel disease in which the cause of the disease
+ is a mutation in the NOD2 gene.
+ Crohn ileitis and jejunitis_(35814),Crohn ileitis and jejunitis belongs to
+ disease category. An Crohn disease involving a pathogenic inflammatory
+ response in the ileum.
+ Crohn's colitis_(83770),Crohn's colitis belongs to disease category.
+ Crohn's disease affecting the colon.
+
+ head_id,edge_type,tail_id
+ Crohn's colitis_(83770),"('disease', 'associated with', 'gene/protein')",
+ ATG16L1_(6661)
+ ATG16L1_(6661),"('gene/protein', 'ppi', 'gene/protein')",IKBKG_(3672)
+ IFNG_(3495),"('gene/protein', 'associated with', 'disease')",
+ inflammatory bowel disease_(28158)
+ IFNG_(3495),"('gene/protein', 'associated with', 'disease')",Crohn's colitis_(83770)
+ IFNG_(3495),"('gene/protein', 'associated with', 'disease')",
+ Crohn ileitis and jejunitis_(35814)
+ """,
+                 "graph_summary": None,
+             }
+         ],
+     }
+
+     return input_dict
+
+
+ def test_summarize_subgraph(input_dict):
+     """
+     Test the subgraph summarization tool without any documents using Ollama model.
+
+     Args:
+         input_dict: Input dictionary fixture.
+     """
+     # Prepare LLM and embedding model
+     input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
+     input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
+
+     # Setup the app
+     unique_id = 12345
+     app = get_app(unique_id, llm_model=input_dict["llm_model"])
+     config = {"configurable": {"thread_id": unique_id}}
+     # Update state
+     app.update_state(
+         config,
+         input_dict,
+     )
+     prompt = """
+     Please directly invoke `subgraph_summarization` tool without calling any other tools
+     to respond to the following prompt:
+
+     You are given a subgraph in the forms of textualized subgraph representing
+     nodes and edges (triples) obtained from extraction_name `subkg_12345`.
+     Summarize the given subgraph and higlight the importance nodes and edges.
+     """
+
+     # Test the tool subgraph_summarization
+     response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
+
+     # Check assistant message
+     assistant_msg = response["messages"][-1].content
+     assert isinstance(assistant_msg, str)
+
+     # Check tool message
+     tool_msg = response["messages"][-2]
+     assert tool_msg.name == "subgraph_summarization"
+
+     # Check summarized subgraph
+     current_state = app.get_state(config)
+     dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
+     assert isinstance(dic_extracted_graph["graph_summary"], str)
@@ -31,7 +31,6 @@ def test_embedding_with_huggingface_embed_query(embedding_model):
      # Check the result
      assert len(result) == 768
 
-
  def test_embedding_with_huggingface_failed():
      """Test embedding documents using the EmbeddingWithHuggingFace class."""
      # Check if the model is available on HuggingFace Hub
@@ -0,0 +1,56 @@
+ """
+ Test cases for utils/embeddings/ollama.py
+ """
+
+ import pytest
+ import ollama
+ from ..utils.embeddings.ollama import EmbeddingWithOllama
+
+ @pytest.fixture(name="ollama_config")
+ def fixture_ollama_config():
+     """Return a dictionary with Ollama configuration."""
+     return {
+         "model_name": "all-minilm", # Choose a small model
+     }
+
+ def test_no_model_ollama(ollama_config):
+     """Test the case when the Ollama model is not available."""
+     cfg = ollama_config
+
+     # Delete the Ollama model
+     try:
+         ollama.delete(cfg["model_name"])
+     except ollama.ResponseError:
+         pass
+
+     # Check if the model is available
+     with pytest.raises(
+         ValueError, match=f"Error: Pulled {cfg['model_name']} model and restarted Ollama server."
+     ):
+         EmbeddingWithOllama(model_name=cfg["model_name"])
+
+ @pytest.fixture(name="embedding_model")
+ def embedding_model_fixture(ollama_config):
+     """Return the configuration object for the Ollama embedding model and model object"""
+     cfg = ollama_config
+     return EmbeddingWithOllama(model_name=cfg["model_name"])
+
+ def test_embedding_with_ollama_embed_documents(embedding_model):
+     """Test embedding documents using the EmbeddingWithOllama class."""
+     # Perform embedding
+     texts = ["Adalimumab", "Infliximab", "Vedolizumab"]
+     result = embedding_model.embed_documents(texts)
+     # Check the result
+     assert len(result) == 3
+     assert len(result[0]) == 384
+
+ def test_embedding_with_ollama_embed_query(embedding_model):
+     """Test embedding a query using the EmbeddingWithOllama class."""
+     # Perform embedding
+     text = "Adalimumab"
+     result = embedding_model.embed_query(text)
+     # Check the result
+     assert len(result) == 384
+
+     # Delete the Ollama model so that it will not be cached afterward
+     ollama.delete(embedding_model.model_name)
@@ -6,20 +6,21 @@ import pytest
  import ollama
  from ..utils.enrichments.ollama import EnrichmentWithOllama
 
+
  @pytest.fixture(name="ollama_config")
  def fixture_ollama_config():
      """Return a dictionary with Ollama configuration."""
      return {
-         "model_name": "smollm2:360m",
+         "model_name": "llama3.2:1b",
          "prompt_enrichment": """
-         Given the input as a list of strings, please return the list of addditional information of
-         each input terms using your prior knowledge.
+         Given the input as a list of strings, please return the list of addditional information
+         of each input terms using your prior knowledge.
 
          Example:
          Input: ['acetaminophen', 'aspirin']
-         Ouput: ['acetaminophen is a medication used to treat pain and fever',
+         Ouput: ['acetaminophen is a medication used to treat pain and fever',
          'aspirin is a medication used to treat pain, fever, and inflammation']
-
+
          Do not include any pretext as the output, only the list of strings enriched.
 
          Input: {input}
@@ -28,10 +29,11 @@ def fixture_ollama_config():
          "streaming": False,
      }
 
+
  def test_no_model_ollama(ollama_config):
      """Test the case when the Ollama model is not available."""
      cfg = ollama_config
-     cfg_model = "smollm2:135m" # Choose a small model
+     cfg_model = "smollm2:135m"  # Choose a small model
 
      # Delete the Ollama model
      try:
@@ -41,7 +43,8 @@ def test_no_model_ollama(ollama_config):
 
      # Check if the model is available
      with pytest.raises(
-         ValueError, match=f"Error: Pulled {cfg_model} model and restarted Ollama server."
+         ValueError,
+         match=f"Error: Pulled {cfg_model} model and restarted Ollama server.",
      ):
          EnrichmentWithOllama(
              model_name=cfg_model,
@@ -51,7 +54,8 @@ def test_no_model_ollama(ollama_config):
          )
      ollama.delete(cfg_model)
 
- def test_enrich_nodes_ollama(ollama_config):
+
+ def test_enrich_ollama(ollama_config):
      """Test the Ollama textual enrichment class for node enrichment."""
      # Prepare enrichment model
      cfg = ollama_config
@@ -63,37 +67,11 @@ def test_enrich_nodes_ollama(ollama_config):
      )
 
      # Perform enrichment for nodes
-     nodes = ["Adalimumab", "Infliximab"]
+     nodes = ["acetaminophen"]
      enriched_nodes = enr_model.enrich_documents(nodes)
      # Check the enriched nodes
-     assert len(enriched_nodes) == 2
-     assert all(
-         enriched_nodes[i] != nodes[i] for i in range(len(nodes))
-     )
-
-
- def test_enrich_relations_ollama(ollama_config):
-     """Test the Ollama textual enrichment class for relation enrichment."""
-     # Prepare enrichment model
-     cfg = ollama_config
-     enr_model = EnrichmentWithOllama(
-         model_name=cfg["model_name"],
-         prompt_enrichment=cfg["prompt_enrichment"],
-         temperature=cfg["temperature"],
-         streaming=cfg["streaming"],
-     )
-     # Perform enrichment for relations
-     relations = [
-         "IL23R-gene causation disease-inflammatory bowel diseases",
-         "NOD2-gene causation disease-inflammatory bowel diseases",
-     ]
-     enriched_relations = enr_model.enrich_documents(relations)
-     # Check the enriched relations
-     assert len(enriched_relations) == 2
-     assert all(
-         enriched_relations[i] != relations[i]
-         for i in range(len(relations))
-     )
+     assert len(enriched_nodes) == 1
+     assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))
 
 
  def test_enrich_ollama_rag(ollama_config):
@@ -107,11 +85,9 @@ def test_enrich_ollama_rag(ollama_config):
          streaming=cfg["streaming"],
      )
      # Perform enrichment for nodes
-     nodes = ["Adalimumab", "Infliximab"]
+     nodes = ["acetaminophen"]
      docs = [r"\path\to\doc1", r"\path\to\doc2"]
      enriched_nodes = enr_model.enrich_documents_with_rag(nodes, docs)
      # Check the enriched nodes
-     assert len(enriched_nodes) == 2
-     assert all(
-         enriched_nodes[i] != nodes[i] for i in range(len(nodes))
-     )
+     assert len(enriched_nodes) == 1
+     assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))
@@ -0,0 +1,79 @@
+ """
+ Test cases for utils/kg_utils.py
+ """
+
+ import pytest
+ import networkx as nx
+ import pandas as pd
+ from ..utils import kg_utils
+
+
+ @pytest.fixture(name="sample_graph")
+ def make_sample_graph():
+     """Return a sample graph"""
+     sg = nx.Graph()
+     sg.add_node(1, node_id=1, feature_id="A", feature_value="ValueA")
+     sg.add_node(2, node_id=2, feature_id="B", feature_value="ValueB")
+     sg.add_edge(1, 2, edge_id=1, feature_id="E", feature_value="EdgeValue")
+     return sg
+
+
+ def test_kg_to_df_pandas(sample_graph):
+     """Test the kg_to_df_pandas function"""
+     df_nodes, df_edges = kg_utils.kg_to_df_pandas(sample_graph)
+     print(df_nodes)
+     expected_nodes_data = {
+         "node_id": [1, 2],
+         "feature_id": ["A", "B"],
+         "feature_value": ["ValueA", "ValueB"],
+     }
+     expected_nodes_df = pd.DataFrame(expected_nodes_data, index=[1, 2])
+     print(expected_nodes_df)
+     expected_edges_data = {
+         "node_source": [1],
+         "node_target": [2],
+         "edge_id": [1],
+         "feature_id": ["E"],
+         "feature_value": ["EdgeValue"],
+     }
+     expected_edges_df = pd.DataFrame(expected_edges_data)
+
+     # Assert that the dataframes are equal but the order of columns may be different
+     # Ignore the index of the dataframes
+     pd.testing.assert_frame_equal(df_nodes, expected_nodes_df, check_like=True)
+     pd.testing.assert_frame_equal(df_edges, expected_edges_df, check_like=True)
+
+
+ def test_df_pandas_to_kg():
+     """Test the df_pandas_to_kg function"""
+     nodes_data = {
+         "node_id": [1, 2],
+         "feature_id": ["A", "B"],
+         "feature_value": ["ValueA", "ValueB"],
+     }
+     df_nodes_attrs = pd.DataFrame(nodes_data).set_index("node_id")
+
+     edges_data = {
+         "node_source": [1],
+         "node_target": [2],
+         "edge_id": [1],
+         "feature_id": ["E"],
+         "feature_value": ["EdgeValue"],
+     }
+     df_edges = pd.DataFrame(edges_data)
+
+     kg = kg_utils.df_pandas_to_kg(
+         df_edges, df_nodes_attrs, "node_source", "node_target"
+     )
+
+     assert len(kg.nodes) == 2
+     assert len(kg.edges) == 1
+
+     assert kg.nodes[1]["feature_id"] == "A"
+     assert kg.nodes[1]["feature_value"] == "ValueA"
+     assert kg.nodes[2]["feature_id"] == "B"
+     assert kg.nodes[2]["feature_value"] == "ValueB"
+
+     assert kg.edges[1, 2]["feature_id"] == "E"
+     assert kg.edges[1, 2]["feature_value"] == "EdgeValue"
+     assert kg.edges[1, 2]["edge_id"] == 1
1
+ '''
2
+ This file is used to import all the models in the package.
3
+ '''
4
+ from . import subgraph_extraction
5
+ from . import subgraph_summarization
6
+ from . import graphrag_reasoning